294 lines
10 KiB
Python
294 lines
10 KiB
Python
|
|
"""
|
|||
|
|
RAG 检索工具
|
|||
|
|
支持语义检索代码
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from typing import Optional, List
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
|
|||
|
|
from .base import AgentTool, ToolResult
|
|||
|
|
from app.services.rag import CodeRetriever
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RAGQueryInput(BaseModel):
|
|||
|
|
"""RAG 查询输入参数"""
|
|||
|
|
query: str = Field(description="搜索查询,描述你要找的代码功能或特征")
|
|||
|
|
top_k: int = Field(default=10, description="返回结果数量")
|
|||
|
|
file_path: Optional[str] = Field(default=None, description="限定搜索的文件路径")
|
|||
|
|
language: Optional[str] = Field(default=None, description="限定编程语言")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RAGQueryTool(AgentTool):
|
|||
|
|
"""
|
|||
|
|
RAG 代码检索工具
|
|||
|
|
使用语义搜索在代码库中查找相关代码
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, retriever: CodeRetriever):
|
|||
|
|
super().__init__()
|
|||
|
|
self.retriever = retriever
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def name(self) -> str:
|
|||
|
|
return "rag_query"
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def description(self) -> str:
|
|||
|
|
return """在代码库中进行语义搜索。
|
|||
|
|
使用场景:
|
|||
|
|
- 查找特定功能的实现代码
|
|||
|
|
- 查找调用某个函数的代码
|
|||
|
|
- 查找处理用户输入的代码
|
|||
|
|
- 查找数据库操作相关代码
|
|||
|
|
- 查找认证/授权相关代码
|
|||
|
|
|
|||
|
|
输入:
|
|||
|
|
- query: 描述你要查找的代码,例如 "处理用户登录的函数"、"SQL查询执行"、"文件上传处理"
|
|||
|
|
- top_k: 返回结果数量(默认10)
|
|||
|
|
- file_path: 可选,限定在某个文件中搜索
|
|||
|
|
- language: 可选,限定编程语言
|
|||
|
|
|
|||
|
|
输出: 相关的代码片段列表,包含文件路径、行号、代码内容和相似度分数"""
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def args_schema(self):
|
|||
|
|
return RAGQueryInput
|
|||
|
|
|
|||
|
|
async def _execute(
|
|||
|
|
self,
|
|||
|
|
query: str,
|
|||
|
|
top_k: int = 10,
|
|||
|
|
file_path: Optional[str] = None,
|
|||
|
|
language: Optional[str] = None,
|
|||
|
|
**kwargs
|
|||
|
|
) -> ToolResult:
|
|||
|
|
"""执行 RAG 检索"""
|
|||
|
|
try:
|
|||
|
|
results = await self.retriever.retrieve(
|
|||
|
|
query=query,
|
|||
|
|
top_k=top_k,
|
|||
|
|
filter_file_path=file_path,
|
|||
|
|
filter_language=language,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if not results:
|
|||
|
|
return ToolResult(
|
|||
|
|
success=True,
|
|||
|
|
data="没有找到相关代码",
|
|||
|
|
metadata={"query": query, "results_count": 0}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 格式化输出
|
|||
|
|
output_parts = [f"找到 {len(results)} 个相关代码片段:\n"]
|
|||
|
|
|
|||
|
|
for i, result in enumerate(results):
|
|||
|
|
output_parts.append(f"\n--- 结果 {i+1} (相似度: {result.score:.2f}) ---")
|
|||
|
|
output_parts.append(f"文件: {result.file_path}")
|
|||
|
|
output_parts.append(f"行号: {result.line_start}-{result.line_end}")
|
|||
|
|
if result.name:
|
|||
|
|
output_parts.append(f"名称: {result.name}")
|
|||
|
|
if result.security_indicators:
|
|||
|
|
output_parts.append(f"安全指标: {', '.join(result.security_indicators)}")
|
|||
|
|
output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```")
|
|||
|
|
|
|||
|
|
return ToolResult(
|
|||
|
|
success=True,
|
|||
|
|
data="\n".join(output_parts),
|
|||
|
|
metadata={
|
|||
|
|
"query": query,
|
|||
|
|
"results_count": len(results),
|
|||
|
|
"results": [r.to_dict() for r in results],
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
return ToolResult(
|
|||
|
|
success=False,
|
|||
|
|
error=f"RAG 检索失败: {str(e)}",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SecurityCodeSearchInput(BaseModel):
|
|||
|
|
"""安全代码搜索输入"""
|
|||
|
|
vulnerability_type: str = Field(
|
|||
|
|
description="漏洞类型: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, auth_bypass, hardcoded_secret"
|
|||
|
|
)
|
|||
|
|
top_k: int = Field(default=20, description="返回结果数量")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SecurityCodeSearchTool(AgentTool):
|
|||
|
|
"""
|
|||
|
|
安全相关代码搜索工具
|
|||
|
|
专门用于查找可能存在安全漏洞的代码
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, retriever: CodeRetriever):
|
|||
|
|
super().__init__()
|
|||
|
|
self.retriever = retriever
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def name(self) -> str:
|
|||
|
|
return "security_code_search"
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def description(self) -> str:
|
|||
|
|
return """搜索可能存在安全漏洞的代码。
|
|||
|
|
专门针对特定漏洞类型进行搜索。
|
|||
|
|
|
|||
|
|
支持的漏洞类型:
|
|||
|
|
- sql_injection: SQL 注入
|
|||
|
|
- xss: 跨站脚本
|
|||
|
|
- command_injection: 命令注入
|
|||
|
|
- path_traversal: 路径遍历
|
|||
|
|
- ssrf: 服务端请求伪造
|
|||
|
|
- deserialization: 不安全的反序列化
|
|||
|
|
- auth_bypass: 认证绕过
|
|||
|
|
- hardcoded_secret: 硬编码密钥"""
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def args_schema(self):
|
|||
|
|
return SecurityCodeSearchInput
|
|||
|
|
|
|||
|
|
async def _execute(
|
|||
|
|
self,
|
|||
|
|
vulnerability_type: str,
|
|||
|
|
top_k: int = 20,
|
|||
|
|
**kwargs
|
|||
|
|
) -> ToolResult:
|
|||
|
|
"""执行安全代码搜索"""
|
|||
|
|
try:
|
|||
|
|
results = await self.retriever.retrieve_security_related(
|
|||
|
|
vulnerability_type=vulnerability_type,
|
|||
|
|
top_k=top_k,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if not results:
|
|||
|
|
return ToolResult(
|
|||
|
|
success=True,
|
|||
|
|
data=f"没有找到与 {vulnerability_type} 相关的代码",
|
|||
|
|
metadata={"vulnerability_type": vulnerability_type, "results_count": 0}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 格式化输出
|
|||
|
|
output_parts = [f"找到 {len(results)} 个可能与 {vulnerability_type} 相关的代码:\n"]
|
|||
|
|
|
|||
|
|
for i, result in enumerate(results):
|
|||
|
|
output_parts.append(f"\n--- 可疑代码 {i+1} ---")
|
|||
|
|
output_parts.append(f"文件: {result.file_path}:{result.line_start}")
|
|||
|
|
if result.security_indicators:
|
|||
|
|
output_parts.append(f"⚠️ 安全指标: {', '.join(result.security_indicators)}")
|
|||
|
|
output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```")
|
|||
|
|
|
|||
|
|
return ToolResult(
|
|||
|
|
success=True,
|
|||
|
|
data="\n".join(output_parts),
|
|||
|
|
metadata={
|
|||
|
|
"vulnerability_type": vulnerability_type,
|
|||
|
|
"results_count": len(results),
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
return ToolResult(
|
|||
|
|
success=False,
|
|||
|
|
error=f"安全代码搜索失败: {str(e)}",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FunctionContextInput(BaseModel):
|
|||
|
|
"""函数上下文搜索输入"""
|
|||
|
|
function_name: str = Field(description="函数名称")
|
|||
|
|
file_path: Optional[str] = Field(default=None, description="文件路径")
|
|||
|
|
include_callers: bool = Field(default=True, description="是否包含调用者")
|
|||
|
|
include_callees: bool = Field(default=True, description="是否包含被调用的函数")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FunctionContextTool(AgentTool):
|
|||
|
|
"""
|
|||
|
|
函数上下文搜索工具
|
|||
|
|
查找函数的定义、调用者和被调用者
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, retriever: CodeRetriever):
|
|||
|
|
super().__init__()
|
|||
|
|
self.retriever = retriever
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def name(self) -> str:
|
|||
|
|
return "function_context"
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def description(self) -> str:
|
|||
|
|
return """查找函数的上下文信息,包括定义、调用者和被调用的函数。
|
|||
|
|
用于追踪数据流和理解函数的使用方式。
|
|||
|
|
|
|||
|
|
输入:
|
|||
|
|
- function_name: 要查找的函数名
|
|||
|
|
- file_path: 可选,限定文件路径
|
|||
|
|
- include_callers: 是否查找调用此函数的代码
|
|||
|
|
- include_callees: 是否查找此函数调用的其他函数"""
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def args_schema(self):
|
|||
|
|
return FunctionContextInput
|
|||
|
|
|
|||
|
|
async def _execute(
|
|||
|
|
self,
|
|||
|
|
function_name: str,
|
|||
|
|
file_path: Optional[str] = None,
|
|||
|
|
include_callers: bool = True,
|
|||
|
|
include_callees: bool = True,
|
|||
|
|
**kwargs
|
|||
|
|
) -> ToolResult:
|
|||
|
|
"""执行函数上下文搜索"""
|
|||
|
|
try:
|
|||
|
|
context = await self.retriever.retrieve_function_context(
|
|||
|
|
function_name=function_name,
|
|||
|
|
file_path=file_path,
|
|||
|
|
include_callers=include_callers,
|
|||
|
|
include_callees=include_callees,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
output_parts = [f"函数 '{function_name}' 的上下文分析:\n"]
|
|||
|
|
|
|||
|
|
# 函数定义
|
|||
|
|
if context["definition"]:
|
|||
|
|
output_parts.append("### 函数定义:")
|
|||
|
|
for result in context["definition"]:
|
|||
|
|
output_parts.append(f"文件: {result.file_path}:{result.line_start}")
|
|||
|
|
output_parts.append(f"```{result.language}\n{result.content}\n```")
|
|||
|
|
else:
|
|||
|
|
output_parts.append("未找到函数定义")
|
|||
|
|
|
|||
|
|
# 调用者
|
|||
|
|
if context["callers"]:
|
|||
|
|
output_parts.append(f"\n### 调用此函数的代码 ({len(context['callers'])} 处):")
|
|||
|
|
for result in context["callers"][:5]:
|
|||
|
|
output_parts.append(f"- {result.file_path}:{result.line_start}")
|
|||
|
|
output_parts.append(f"```{result.language}\n{result.content[:500]}\n```")
|
|||
|
|
|
|||
|
|
# 被调用者
|
|||
|
|
if context["callees"]:
|
|||
|
|
output_parts.append(f"\n### 此函数调用的其他函数:")
|
|||
|
|
for result in context["callees"][:5]:
|
|||
|
|
if result.name:
|
|||
|
|
output_parts.append(f"- {result.name} ({result.file_path})")
|
|||
|
|
|
|||
|
|
return ToolResult(
|
|||
|
|
success=True,
|
|||
|
|
data="\n".join(output_parts),
|
|||
|
|
metadata={
|
|||
|
|
"function_name": function_name,
|
|||
|
|
"definition_count": len(context["definition"]),
|
|||
|
|
"callers_count": len(context["callers"]),
|
|||
|
|
"callees_count": len(context["callees"]),
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
return ToolResult(
|
|||
|
|
success=False,
|
|||
|
|
error=f"函数上下文搜索失败: {str(e)}",
|
|||
|
|
)
|
|||
|
|
|