From 5f074038501bd7cfba5c406782048872f9fda3da Mon Sep 17 00:00:00 2001 From: lintsinghua Date: Tue, 16 Dec 2025 13:57:27 +0800 Subject: [PATCH] =?UTF-8?q?feat(agent):=20=E5=A2=9E=E5=BC=BA=20RAG=20?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E5=92=8C=E5=B7=A5=E5=85=B7=E9=9B=86=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 扩展嵌入模型配置选项,支持独立 API Key 和 Base URL - 重构 RAG 初始化逻辑,支持用户自定义嵌入配置 - 新增语义搜索工具并集成到 Recon 和 Analysis Agent - 完善系统提示,明确不同代码搜索工具的使用场景 --- backend/app/api/v1/endpoints/agent_tasks.py | 99 +++++++++++++++++-- backend/app/core/config.py | 8 +- backend/app/services/agent/graph/runner.py | 54 ++++++---- .../services/agent/prompts/system_prompts.py | 25 ++++- 4 files changed, 159 insertions(+), 27 deletions(-) diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index 4bbec44..420c4f6 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -296,12 +296,13 @@ async def _execute_agent_task(task_id: str): # 初始化工具集 - 传递排除模式和目标文件以及预初始化的 sandbox_manager tools = await _initialize_tools( - project_root, - llm_service, + project_root, + llm_service, user_config, sandbox_manager=sandbox_manager, exclude_patterns=task.exclude_patterns, target_files=task.target_files, + project_id=str(project.id), # 🔥 传递 project_id 用于 RAG ) # 创建子 Agent @@ -552,15 +553,16 @@ async def _get_user_config(db: AsyncSession, user_id: Optional[str]) -> Optional async def _initialize_tools( - project_root: str, - llm_service, + project_root: str, + llm_service, user_config: Optional[Dict[str, Any]], sandbox_manager: Any, # 传递预初始化的 SandboxManager exclude_patterns: Optional[List[str]] = None, target_files: Optional[List[str]] = None, + project_id: Optional[str] = None, # 🔥 用于 RAG collection_name ) -> Dict[str, Dict[str, Any]]: """初始化工具集 - + Args: project_root: 项目根目录 llm_service: LLM 服务 @@ -568,6 +570,7 @@ async def _initialize_tools( sandbox_manager: 沙箱管理器 exclude_patterns: 排除模式列表 target_files: 目标文件列表 + project_id: 项目 ID(用于 RAG collection_name) """ from app.services.agent.tools import ( FileReadTool, FileSearchTool, ListFilesTool, @@ -577,12 +580,82 @@ async def _initialize_tools( ThinkTool, ReflectTool, CreateVulnerabilityReportTool, VulnerabilityValidationTool, + # 🔥 RAG 工具 + RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool, ) from app.services.agent.knowledge import ( SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool, ) - + # 🔥 RAG 相关导入 + from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService + from app.core.config import settings + + # ============ 🔥 初始化 RAG 系统 ============ + retriever = None + try: + # 从用户配置中获取 embedding 配置 + user_llm_config = (user_config or {}).get('llmConfig', {}) + user_other_config = (user_config or {}).get('otherConfig', {}) + user_embedding_config = user_other_config.get('embedding_config', {}) + + # Embedding Provider 优先级:用户嵌入配置 > 环境变量 + embedding_provider = ( + user_embedding_config.get('provider') or + getattr(settings, 'EMBEDDING_PROVIDER', 'openai') + ) + + # Embedding Model 优先级:用户嵌入配置 > 环境变量 + embedding_model = ( + user_embedding_config.get('model') or + getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') + ) + + # API Key 优先级:用户嵌入配置 > 环境变量 EMBEDDING_API_KEY > 用户 LLM 配置 > 环境变量 LLM_API_KEY + # 注意:API Key 可以共享,因为很多用户使用同一个 OpenAI Key 做 LLM 和 Embedding + embedding_api_key = ( + user_embedding_config.get('api_key') or + getattr(settings, 'EMBEDDING_API_KEY', None) or + user_llm_config.get('llmApiKey') or + getattr(settings, 'LLM_API_KEY', '') or + '' + ) + + # Base URL 优先级:用户嵌入配置 > 环境变量 EMBEDDING_BASE_URL > None(使用提供商默认地址) + # 🔥 重要:Base URL 不应该回退到 LLM 的 base_url,因为 Embedding 和 LLM 可能使用完全不同的服务 + # 例如:LLM 使用 SiliconFlow,但 Embedding 使用 HuggingFace + embedding_base_url = ( + user_embedding_config.get('base_url') or + getattr(settings, 'EMBEDDING_BASE_URL', None) or + None + ) + + logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}, base_url={embedding_base_url or '(使用默认)'}") + + # 创建 Embedding 服务 + embedding_service = EmbeddingService( + provider=embedding_provider, + model=embedding_model, + api_key=embedding_api_key, + base_url=embedding_base_url, + ) + + # 创建 collection_name(基于 project_id) + collection_name = f"project_{project_id}" if project_id else "default_project" + + # 创建 CodeRetriever(用于搜索) + retriever = CodeRetriever( + collection_name=collection_name, + embedding_service=embedding_service, + persist_directory=settings.VECTOR_DB_PATH, + ) + + logger.info(f"✅ RAG 系统初始化成功: collection={collection_name}") + + except Exception as e: + logger.warning(f"⚠️ RAG 系统初始化失败: {e}") + retriever = None + # 基础工具 - 传递排除模式和目标文件 base_tools = { "read_file": FileReadTool(project_root, exclude_patterns, target_files), @@ -604,6 +677,11 @@ async def _initialize_tools( "trufflehog_scan": TruffleHogTool(project_root, sandbox_manager), "osv_scan": OSVScannerTool(project_root, sandbox_manager), } + + # 🔥 注册 RAG 工具到 Recon Agent + if retriever: + recon_tools["rag_query"] = RAGQueryTool(retriever) + logger.info("✅ RAG 工具 (rag_query) 已注册到 Recon Agent") # Analysis 工具 # 🔥 导入智能扫描工具 @@ -630,6 +708,15 @@ async def _initialize_tools( "query_security_knowledge": SecurityKnowledgeQueryTool(), "get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(), } + + # 🔥 注册 RAG 工具到 Analysis Agent + if retriever: + analysis_tools["rag_query"] = RAGQueryTool(retriever) + analysis_tools["security_search"] = SecurityCodeSearchTool(retriever) + analysis_tools["function_context"] = FunctionContextTool(retriever) + logger.info("✅ RAG 工具 (rag_query, security_search, function_context) 已注册到 Analysis Agent") + else: + logger.warning("⚠️ RAG 未初始化,rag_query/security_search/function_context 工具不可用") # Verification 工具 # 🔥 导入沙箱工具 diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 43a71bb..43d6385 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -78,10 +78,12 @@ class Settings(BaseSettings): OUTPUT_LANGUAGE: str = "zh-CN" # ============ Agent 模块配置 ============ - - # 嵌入模型配置 - EMBEDDING_PROVIDER: str = "openai" # openai, ollama, litellm + + # 嵌入模型配置(独立于 LLM 配置) + EMBEDDING_PROVIDER: str = "openai" # openai, azure, ollama, cohere, huggingface, jina EMBEDDING_MODEL: str = "text-embedding-3-small" + EMBEDDING_API_KEY: Optional[str] = None # 嵌入模型专用 API Key(留空则使用 LLM_API_KEY) + EMBEDDING_BASE_URL: Optional[str] = None # 嵌入模型专用 Base URL(留空使用提供商默认地址) # 向量数据库配置 VECTOR_DB_PATH: str = "./data/vector_db" # 向量数据库持久化目录 diff --git a/backend/app/services/agent/graph/runner.py b/backend/app/services/agent/graph/runner.py index 31f41cd..031cad3 100644 --- a/backend/app/services/agent/graph/runner.py +++ b/backend/app/services/agent/graph/runner.py @@ -142,30 +142,45 @@ class AgentRunner: async def _initialize_rag(self): """初始化 RAG 系统""" await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...") - + try: - # 🔥 从用户配置中获取 LLM 配置(用于 Embedding API Key) - # 优先级:用户配置 > 环境变量 + # 🔥 从用户配置中获取配置 + # 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量 user_llm_config = self.user_config.get('llmConfig', {}) - - # 获取 Embedding 配置(优先使用用户配置的 LLM API Key) - embedding_provider = getattr(settings, 'EMBEDDING_PROVIDER', 'openai') - embedding_model = getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') - - # 🔥 API Key 优先级:用户配置 > 环境变量 + user_other_config = self.user_config.get('otherConfig', {}) + user_embedding_config = user_other_config.get('embedding_config', {}) + + # 🔥 Embedding Provider 优先级:用户嵌入配置 > 环境变量 + embedding_provider = ( + user_embedding_config.get('provider') or + getattr(settings, 'EMBEDDING_PROVIDER', 'openai') + ) + + # 🔥 Embedding Model 优先级:用户嵌入配置 > 环境变量 + embedding_model = ( + user_embedding_config.get('model') or + getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') + ) + + # 🔥 API Key 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量 embedding_api_key = ( + user_embedding_config.get('api_key') or user_llm_config.get('llmApiKey') or getattr(settings, 'LLM_API_KEY', '') or '' ) - - # 🔥 Base URL 优先级:用户配置 > 环境变量 + + # 🔥 Base URL 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量 embedding_base_url = ( + user_embedding_config.get('base_url') or user_llm_config.get('llmBaseUrl') or getattr(settings, 'LLM_BASE_URL', None) or None ) - + + logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}") + await self.event_emitter.emit_info(f"嵌入模型: {embedding_provider}/{embedding_model}") + embedding_service = EmbeddingService( provider=embedding_provider, model=embedding_model, @@ -267,11 +282,14 @@ class AgentRunner: "safety_scan": SafetyTool(self.project_root, self.sandbox_manager), "npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager), } - + # RAG 工具(Recon 用于语义搜索) if self.retriever: self.recon_tools["rag_query"] = RAGQueryTool(self.retriever) - + logger.info("✅ RAG 工具已注册到 Recon Agent") + else: + logger.warning("⚠️ RAG 未初始化,rag_query 工具不可用") + # ============ Analysis Agent 专属工具 ============ # 职责:漏洞分析、代码审计、模式匹配 self.analysis_tools = { @@ -300,11 +318,13 @@ class AgentRunner: "query_security_knowledge": SecurityKnowledgeQueryTool(), "get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(), } - + # RAG 工具(Analysis 用于安全相关代码搜索) if self.retriever: - self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever) - self.analysis_tools["function_context"] = FunctionContextTool(self.retriever) + self.analysis_tools["rag_query"] = RAGQueryTool(self.retriever) # 通用语义搜索 + self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever) # 安全代码搜索 + self.analysis_tools["function_context"] = FunctionContextTool(self.retriever) # 函数上下文 + logger.info("✅ RAG 工具已注册到 Analysis Agent (rag_query, security_search, function_context)") # ============ Verification Agent 专属工具 ============ # 职责:漏洞验证、PoC 执行、误报排除 diff --git a/backend/app/services/agent/prompts/system_prompts.py b/backend/app/services/agent/prompts/system_prompts.py index 52bfbd6..75c1c07 100644 --- a/backend/app/services/agent/prompts/system_prompts.py +++ b/backend/app/services/agent/prompts/system_prompts.py @@ -142,11 +142,28 @@ TOOL_USAGE_GUIDE = """ #### 辅助工具 | 工具 | 用途 | |------|------| +| `rag_query` | **语义搜索代码**(推荐!比 search_code 更智能,理解代码含义) | +| `security_search` | **安全相关代码搜索**(专门查找安全敏感代码) | +| `function_context` | **函数上下文搜索**(获取函数的调用关系和上下文) | | `list_files` | 了解项目结构 | | `read_file` | 读取文件内容验证发现 | -| `search_code` | 搜索相关代码 | +| `search_code` | 关键词搜索代码(精确匹配) | | `query_security_knowledge` | 查询安全知识库 | +### 🔍 代码搜索工具对比 +| 工具 | 特点 | 适用场景 | +|------|------|---------| +| `rag_query` | **语义搜索**,理解代码含义 | 查找"处理用户输入的函数"、"数据库查询逻辑" | +| `security_search` | **安全专用搜索** | 查找"SQL注入相关代码"、"认证授权代码" | +| `function_context` | **函数上下文** | 查找某函数的调用者和被调用者 | +| `search_code` | **关键词搜索**,精确匹配 | 查找特定函数名、变量名、字符串 | + +**推荐**: +1. 查找安全相关代码时优先使用 `security_search` +2. 理解函数关系时使用 `function_context` +3. 通用语义搜索使用 `rag_query` +4. 精确匹配时使用 `search_code` + ### 📋 推荐分析流程 #### 第一步:快速侦察(5%时间) @@ -156,6 +173,12 @@ Action Input: {"path": "."} ``` 了解项目结构、技术栈、入口点 +**语义搜索高风险代码(推荐!):** +``` +Action: rag_query +Action Input: {"query": "处理用户输入或执行数据库查询的函数", "top_k": 10} +``` + #### 第二步:外部工具全面扫描(60%时间)⚡重点! **根据技术栈选择对应工具,并行执行多个扫描:**