CodeReview/backend/app/services/agent/tools/rag_tool.py

294 lines
10 KiB
Python
Raw Normal View History

"""
RAG 检索工具
支持语义检索代码
"""
from typing import Optional, List
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
from app.services.rag import CodeRetriever
class RAGQueryInput(BaseModel):
"""RAG 查询输入参数"""
query: str = Field(description="搜索查询,描述你要找的代码功能或特征")
top_k: int = Field(default=10, description="返回结果数量")
file_path: Optional[str] = Field(default=None, description="限定搜索的文件路径")
language: Optional[str] = Field(default=None, description="限定编程语言")
class RAGQueryTool(AgentTool):
"""
RAG 代码检索工具
使用语义搜索在代码库中查找相关代码
"""
def __init__(self, retriever: CodeRetriever):
super().__init__()
self.retriever = retriever
@property
def name(self) -> str:
return "rag_query"
@property
def description(self) -> str:
return """在代码库中进行语义搜索。
使用场景:
- 查找特定功能的实现代码
- 查找调用某个函数的代码
- 查找处理用户输入的代码
- 查找数据库操作相关代码
- 查找认证/授权相关代码
输入:
- query: 描述你要查找的代码例如 "处理用户登录的函数""SQL查询执行""文件上传处理"
- top_k: 返回结果数量默认10
- file_path: 可选限定在某个文件中搜索
- language: 可选限定编程语言
输出: 相关的代码片段列表包含文件路径行号代码内容和相似度分数"""
@property
def args_schema(self):
return RAGQueryInput
async def _execute(
self,
query: str,
top_k: int = 10,
file_path: Optional[str] = None,
language: Optional[str] = None,
**kwargs
) -> ToolResult:
"""执行 RAG 检索"""
try:
results = await self.retriever.retrieve(
query=query,
top_k=top_k,
filter_file_path=file_path,
filter_language=language,
)
if not results:
return ToolResult(
success=True,
data="没有找到相关代码",
metadata={"query": query, "results_count": 0}
)
# 格式化输出
output_parts = [f"找到 {len(results)} 个相关代码片段:\n"]
for i, result in enumerate(results):
output_parts.append(f"\n--- 结果 {i+1} (相似度: {result.score:.2f}) ---")
output_parts.append(f"文件: {result.file_path}")
output_parts.append(f"行号: {result.line_start}-{result.line_end}")
if result.name:
output_parts.append(f"名称: {result.name}")
if result.security_indicators:
output_parts.append(f"安全指标: {', '.join(result.security_indicators)}")
output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
"query": query,
"results_count": len(results),
"results": [r.to_dict() for r in results],
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"RAG 检索失败: {str(e)}",
)
class SecurityCodeSearchInput(BaseModel):
"""安全代码搜索输入"""
vulnerability_type: str = Field(
description="漏洞类型: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, auth_bypass, hardcoded_secret"
)
top_k: int = Field(default=20, description="返回结果数量")
class SecurityCodeSearchTool(AgentTool):
"""
安全相关代码搜索工具
专门用于查找可能存在安全漏洞的代码
"""
def __init__(self, retriever: CodeRetriever):
super().__init__()
self.retriever = retriever
@property
def name(self) -> str:
return "security_code_search"
@property
def description(self) -> str:
return """搜索可能存在安全漏洞的代码。
专门针对特定漏洞类型进行搜索
支持的漏洞类型:
- sql_injection: SQL 注入
- xss: 跨站脚本
- command_injection: 命令注入
- path_traversal: 路径遍历
- ssrf: 服务端请求伪造
- deserialization: 不安全的反序列化
- auth_bypass: 认证绕过
- hardcoded_secret: 硬编码密钥"""
@property
def args_schema(self):
return SecurityCodeSearchInput
async def _execute(
self,
vulnerability_type: str,
top_k: int = 20,
**kwargs
) -> ToolResult:
"""执行安全代码搜索"""
try:
results = await self.retriever.retrieve_security_related(
vulnerability_type=vulnerability_type,
top_k=top_k,
)
if not results:
return ToolResult(
success=True,
data=f"没有找到与 {vulnerability_type} 相关的代码",
metadata={"vulnerability_type": vulnerability_type, "results_count": 0}
)
# 格式化输出
output_parts = [f"找到 {len(results)} 个可能与 {vulnerability_type} 相关的代码:\n"]
for i, result in enumerate(results):
output_parts.append(f"\n--- 可疑代码 {i+1} ---")
output_parts.append(f"文件: {result.file_path}:{result.line_start}")
if result.security_indicators:
output_parts.append(f"⚠️ 安全指标: {', '.join(result.security_indicators)}")
output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
"vulnerability_type": vulnerability_type,
"results_count": len(results),
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"安全代码搜索失败: {str(e)}",
)
class FunctionContextInput(BaseModel):
"""函数上下文搜索输入"""
function_name: str = Field(description="函数名称")
file_path: Optional[str] = Field(default=None, description="文件路径")
include_callers: bool = Field(default=True, description="是否包含调用者")
include_callees: bool = Field(default=True, description="是否包含被调用的函数")
class FunctionContextTool(AgentTool):
"""
函数上下文搜索工具
查找函数的定义调用者和被调用者
"""
def __init__(self, retriever: CodeRetriever):
super().__init__()
self.retriever = retriever
@property
def name(self) -> str:
return "function_context"
@property
def description(self) -> str:
return """查找函数的上下文信息,包括定义、调用者和被调用的函数。
用于追踪数据流和理解函数的使用方式
输入:
- function_name: 要查找的函数名
- file_path: 可选限定文件路径
- include_callers: 是否查找调用此函数的代码
- include_callees: 是否查找此函数调用的其他函数"""
@property
def args_schema(self):
return FunctionContextInput
async def _execute(
self,
function_name: str,
file_path: Optional[str] = None,
include_callers: bool = True,
include_callees: bool = True,
**kwargs
) -> ToolResult:
"""执行函数上下文搜索"""
try:
context = await self.retriever.retrieve_function_context(
function_name=function_name,
file_path=file_path,
include_callers=include_callers,
include_callees=include_callees,
)
output_parts = [f"函数 '{function_name}' 的上下文分析:\n"]
# 函数定义
if context["definition"]:
output_parts.append("### 函数定义:")
for result in context["definition"]:
output_parts.append(f"文件: {result.file_path}:{result.line_start}")
output_parts.append(f"```{result.language}\n{result.content}\n```")
else:
output_parts.append("未找到函数定义")
# 调用者
if context["callers"]:
output_parts.append(f"\n### 调用此函数的代码 ({len(context['callers'])} 处):")
for result in context["callers"][:5]:
output_parts.append(f"- {result.file_path}:{result.line_start}")
output_parts.append(f"```{result.language}\n{result.content[:500]}\n```")
# 被调用者
if context["callees"]:
output_parts.append(f"\n### 此函数调用的其他函数:")
for result in context["callees"][:5]:
if result.name:
output_parts.append(f"- {result.name} ({result.file_path})")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
"function_name": function_name,
"definition_count": len(context["definition"]),
"callers_count": len(context["callers"]),
"callees_count": len(context["callees"]),
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"函数上下文搜索失败: {str(e)}",
)