""" RAG 检索工具 支持语义检索代码 """ from typing import Optional, List from pydantic import BaseModel, Field from .base import AgentTool, ToolResult from app.services.rag import CodeRetriever class RAGQueryInput(BaseModel): """RAG 查询输入参数""" query: str = Field(description="搜索查询,描述你要找的代码功能或特征") top_k: int = Field(default=10, description="返回结果数量") file_path: Optional[str] = Field(default=None, description="限定搜索的文件路径") language: Optional[str] = Field(default=None, description="限定编程语言") class RAGQueryTool(AgentTool): """ RAG 代码检索工具 使用语义搜索在代码库中查找相关代码 """ def __init__(self, retriever: CodeRetriever): super().__init__() self.retriever = retriever @property def name(self) -> str: return "rag_query" @property def description(self) -> str: return """在代码库中进行语义搜索。 使用场景: - 查找特定功能的实现代码 - 查找调用某个函数的代码 - 查找处理用户输入的代码 - 查找数据库操作相关代码 - 查找认证/授权相关代码 输入: - query: 描述你要查找的代码,例如 "处理用户登录的函数"、"SQL查询执行"、"文件上传处理" - top_k: 返回结果数量(默认10) - file_path: 可选,限定在某个文件中搜索 - language: 可选,限定编程语言 输出: 相关的代码片段列表,包含文件路径、行号、代码内容和相似度分数""" @property def args_schema(self): return RAGQueryInput async def _execute( self, query: str, top_k: int = 10, file_path: Optional[str] = None, language: Optional[str] = None, **kwargs ) -> ToolResult: """执行 RAG 检索""" try: results = await self.retriever.retrieve( query=query, top_k=top_k, filter_file_path=file_path, filter_language=language, ) if not results: return ToolResult( success=True, data="没有找到相关代码", metadata={"query": query, "results_count": 0} ) # 格式化输出 output_parts = [f"找到 {len(results)} 个相关代码片段:\n"] for i, result in enumerate(results): output_parts.append(f"\n--- 结果 {i+1} (相似度: {result.score:.2f}) ---") output_parts.append(f"文件: {result.file_path}") output_parts.append(f"行号: {result.line_start}-{result.line_end}") if result.name: output_parts.append(f"名称: {result.name}") if result.security_indicators: output_parts.append(f"安全指标: {', '.join(result.security_indicators)}") output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```") return ToolResult( success=True, data="\n".join(output_parts), metadata={ "query": query, "results_count": len(results), "results": [r.to_dict() for r in results], } ) except Exception as e: return ToolResult( success=False, error=f"RAG 检索失败: {str(e)}", ) class SecurityCodeSearchInput(BaseModel): """安全代码搜索输入""" vulnerability_type: str = Field( description="漏洞类型: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, auth_bypass, hardcoded_secret" ) top_k: int = Field(default=20, description="返回结果数量") class SecurityCodeSearchTool(AgentTool): """ 安全相关代码搜索工具 专门用于查找可能存在安全漏洞的代码 """ def __init__(self, retriever: CodeRetriever): super().__init__() self.retriever = retriever @property def name(self) -> str: return "security_code_search" @property def description(self) -> str: return """搜索可能存在安全漏洞的代码。 专门针对特定漏洞类型进行搜索。 支持的漏洞类型: - sql_injection: SQL 注入 - xss: 跨站脚本 - command_injection: 命令注入 - path_traversal: 路径遍历 - ssrf: 服务端请求伪造 - deserialization: 不安全的反序列化 - auth_bypass: 认证绕过 - hardcoded_secret: 硬编码密钥""" @property def args_schema(self): return SecurityCodeSearchInput async def _execute( self, vulnerability_type: str, top_k: int = 20, **kwargs ) -> ToolResult: """执行安全代码搜索""" try: results = await self.retriever.retrieve_security_related( vulnerability_type=vulnerability_type, top_k=top_k, ) if not results: return ToolResult( success=True, data=f"没有找到与 {vulnerability_type} 相关的代码", metadata={"vulnerability_type": vulnerability_type, "results_count": 0} ) # 格式化输出 output_parts = [f"找到 {len(results)} 个可能与 {vulnerability_type} 相关的代码:\n"] for i, result in enumerate(results): output_parts.append(f"\n--- 可疑代码 {i+1} ---") output_parts.append(f"文件: {result.file_path}:{result.line_start}") if result.security_indicators: output_parts.append(f"⚠️ 安全指标: {', '.join(result.security_indicators)}") output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```") return ToolResult( success=True, data="\n".join(output_parts), metadata={ "vulnerability_type": vulnerability_type, "results_count": len(results), } ) except Exception as e: return ToolResult( success=False, error=f"安全代码搜索失败: {str(e)}", ) class FunctionContextInput(BaseModel): """函数上下文搜索输入""" function_name: str = Field(description="函数名称") file_path: Optional[str] = Field(default=None, description="文件路径") include_callers: bool = Field(default=True, description="是否包含调用者") include_callees: bool = Field(default=True, description="是否包含被调用的函数") class FunctionContextTool(AgentTool): """ 函数上下文搜索工具 查找函数的定义、调用者和被调用者 """ def __init__(self, retriever: CodeRetriever): super().__init__() self.retriever = retriever @property def name(self) -> str: return "function_context" @property def description(self) -> str: return """查找函数的上下文信息,包括定义、调用者和被调用的函数。 用于追踪数据流和理解函数的使用方式。 输入: - function_name: 要查找的函数名 - file_path: 可选,限定文件路径 - include_callers: 是否查找调用此函数的代码 - include_callees: 是否查找此函数调用的其他函数""" @property def args_schema(self): return FunctionContextInput async def _execute( self, function_name: str, file_path: Optional[str] = None, include_callers: bool = True, include_callees: bool = True, **kwargs ) -> ToolResult: """执行函数上下文搜索""" try: context = await self.retriever.retrieve_function_context( function_name=function_name, file_path=file_path, include_callers=include_callers, include_callees=include_callees, ) output_parts = [f"函数 '{function_name}' 的上下文分析:\n"] # 函数定义 if context["definition"]: output_parts.append("### 函数定义:") for result in context["definition"]: output_parts.append(f"文件: {result.file_path}:{result.line_start}") output_parts.append(f"```{result.language}\n{result.content}\n```") else: output_parts.append("未找到函数定义") # 调用者 if context["callers"]: output_parts.append(f"\n### 调用此函数的代码 ({len(context['callers'])} 处):") for result in context["callers"][:5]: output_parts.append(f"- {result.file_path}:{result.line_start}") output_parts.append(f"```{result.language}\n{result.content[:500]}\n```") # 被调用者 if context["callees"]: output_parts.append(f"\n### 此函数调用的其他函数:") for result in context["callees"][:5]: if result.name: output_parts.append(f"- {result.name} ({result.file_path})") return ToolResult( success=True, data="\n".join(output_parts), metadata={ "function_name": function_name, "definition_count": len(context["definition"]), "callers_count": len(context["callers"]), "callees_count": len(context["callees"]), } ) except Exception as e: return ToolResult( success=False, error=f"函数上下文搜索失败: {str(e)}", )