CodeReview/backend/app/services/agent/tools/file_tool.py

482 lines
16 KiB
Python
Raw Normal View History

"""
文件操作工具
读取和搜索代码文件
"""
import os
import re
import fnmatch
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
class FileReadInput(BaseModel):
"""文件读取输入"""
file_path: str = Field(description="文件路径(相对于项目根目录)")
start_line: Optional[int] = Field(default=None, description="起始行号从1开始")
end_line: Optional[int] = Field(default=None, description="结束行号")
max_lines: int = Field(default=500, description="最大返回行数")
class FileReadTool(AgentTool):
"""
文件读取工具
读取项目中的文件内容
"""
def __init__(self, project_root: str):
"""
初始化文件读取工具
Args:
project_root: 项目根目录
"""
super().__init__()
self.project_root = project_root
@property
def name(self) -> str:
return "read_file"
@property
def description(self) -> str:
return """读取项目中的文件内容。
使用场景:
- 查看完整的源代码文件
- 查看特定行范围的代码
- 获取配置文件内容
输入:
- file_path: 文件路径相对于项目根目录
- start_line: 可选起始行号
- end_line: 可选结束行号
- max_lines: 最大返回行数默认500
注意: 为避免输出过长建议指定行范围或使用 RAG 搜索定位代码"""
@property
def args_schema(self):
return FileReadInput
async def _execute(
self,
file_path: str,
start_line: Optional[int] = None,
end_line: Optional[int] = None,
max_lines: int = 500,
**kwargs
) -> ToolResult:
"""执行文件读取"""
try:
# 安全检查:防止路径遍历
full_path = os.path.normpath(os.path.join(self.project_root, file_path))
if not full_path.startswith(os.path.normpath(self.project_root)):
return ToolResult(
success=False,
error="安全错误:不允许访问项目目录外的文件",
)
if not os.path.exists(full_path):
return ToolResult(
success=False,
error=f"文件不存在: {file_path}",
)
if not os.path.isfile(full_path):
return ToolResult(
success=False,
error=f"不是文件: {file_path}",
)
# 检查文件大小
file_size = os.path.getsize(full_path)
if file_size > 1024 * 1024: # 1MB
return ToolResult(
success=False,
error=f"文件过大 ({file_size / 1024:.1f}KB),请指定行范围",
)
# 读取文件
with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
total_lines = len(lines)
# 处理行范围
if start_line is not None:
start_idx = max(0, start_line - 1)
else:
start_idx = 0
if end_line is not None:
end_idx = min(total_lines, end_line)
else:
end_idx = min(total_lines, start_idx + max_lines)
# 截取指定行
selected_lines = lines[start_idx:end_idx]
# 添加行号
numbered_lines = []
for i, line in enumerate(selected_lines, start=start_idx + 1):
numbered_lines.append(f"{i:4d}| {line.rstrip()}")
content = '\n'.join(numbered_lines)
# 检测语言
ext = os.path.splitext(file_path)[1].lower()
language = {
".py": "python", ".js": "javascript", ".ts": "typescript",
".java": "java", ".go": "go", ".rs": "rust",
".cpp": "cpp", ".c": "c", ".cs": "csharp",
".php": "php", ".rb": "ruby", ".swift": "swift",
}.get(ext, "text")
output = f"📄 文件: {file_path}\n"
output += f"行数: {start_idx + 1}-{end_idx} / {total_lines}\n\n"
output += f"```{language}\n{content}\n```"
if end_idx < total_lines:
output += f"\n\n... 还有 {total_lines - end_idx} 行未显示"
return ToolResult(
success=True,
data=output,
metadata={
"file_path": file_path,
"total_lines": total_lines,
"start_line": start_idx + 1,
"end_line": end_idx,
"language": language,
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"读取文件失败: {str(e)}",
)
class FileSearchInput(BaseModel):
"""文件搜索输入"""
keyword: str = Field(description="搜索关键字或正则表达式")
file_pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py, *.js")
directory: Optional[str] = Field(default=None, description="搜索目录(相对路径)")
case_sensitive: bool = Field(default=False, description="是否区分大小写")
max_results: int = Field(default=50, description="最大结果数")
is_regex: bool = Field(default=False, description="是否使用正则表达式")
class FileSearchTool(AgentTool):
"""
文件搜索工具
在项目中搜索包含特定内容的代码
"""
# 排除的目录
EXCLUDE_DIRS = {
"node_modules", "vendor", "dist", "build", ".git",
"__pycache__", ".pytest_cache", "coverage", ".nyc_output",
".vscode", ".idea", ".vs", "target", "venv", "env",
}
def __init__(self, project_root: str):
super().__init__()
self.project_root = project_root
@property
def name(self) -> str:
return "search_code"
@property
def description(self) -> str:
return """在项目代码中搜索关键字或模式。
使用场景:
- 查找特定函数的所有调用位置
- 搜索特定的 API 使用
- 查找包含特定模式的代码
输入:
- keyword: 搜索关键字或正则表达式
- file_pattern: 可选文件名模式 *.py
- directory: 可选搜索目录
- case_sensitive: 是否区分大小写
- is_regex: 是否使用正则表达式
这是一个快速搜索工具结果包含匹配行和上下文"""
@property
def args_schema(self):
return FileSearchInput
async def _execute(
self,
keyword: str,
file_pattern: Optional[str] = None,
directory: Optional[str] = None,
case_sensitive: bool = False,
max_results: int = 50,
is_regex: bool = False,
**kwargs
) -> ToolResult:
"""执行文件搜索"""
try:
# 确定搜索目录
if directory:
search_dir = os.path.normpath(os.path.join(self.project_root, directory))
if not search_dir.startswith(os.path.normpath(self.project_root)):
return ToolResult(
success=False,
error="安全错误:不允许搜索项目目录外的内容",
)
else:
search_dir = self.project_root
# 编译搜索模式
flags = 0 if case_sensitive else re.IGNORECASE
try:
if is_regex:
pattern = re.compile(keyword, flags)
else:
pattern = re.compile(re.escape(keyword), flags)
except re.error as e:
return ToolResult(
success=False,
error=f"无效的搜索模式: {e}",
)
results = []
files_searched = 0
# 遍历文件
for root, dirs, files in os.walk(search_dir):
# 排除目录
dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS]
for filename in files:
# 检查文件模式
if file_pattern and not fnmatch.fnmatch(filename, file_pattern):
continue
file_path = os.path.join(root, filename)
relative_path = os.path.relpath(file_path, self.project_root)
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
files_searched += 1
for i, line in enumerate(lines):
if pattern.search(line):
# 获取上下文
start = max(0, i - 1)
end = min(len(lines), i + 2)
context_lines = []
for j in range(start, end):
prefix = ">" if j == i else " "
context_lines.append(f"{prefix} {j+1:4d}| {lines[j].rstrip()}")
results.append({
"file": relative_path,
"line": i + 1,
"match": line.strip()[:200],
"context": '\n'.join(context_lines),
})
if len(results) >= max_results:
break
if len(results) >= max_results:
break
except Exception:
continue
if len(results) >= max_results:
break
if not results:
return ToolResult(
success=True,
data=f"没有找到匹配 '{keyword}' 的内容\n搜索了 {files_searched} 个文件",
metadata={"files_searched": files_searched, "matches": 0}
)
# 格式化输出
output_parts = [f"🔍 搜索结果: '{keyword}'\n"]
output_parts.append(f"找到 {len(results)} 处匹配(搜索了 {files_searched} 个文件)\n")
for result in results:
output_parts.append(f"\n📄 {result['file']}:{result['line']}")
output_parts.append(f"```\n{result['context']}\n```")
if len(results) >= max_results:
output_parts.append(f"\n... 结果已截断(最大 {max_results} 条)")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
"keyword": keyword,
"files_searched": files_searched,
"matches": len(results),
"results": results[:10], # 只在元数据中保留前10个
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"搜索失败: {str(e)}",
)
class ListFilesInput(BaseModel):
"""列出文件输入"""
directory: str = Field(default=".", description="目录路径(相对于项目根目录)")
pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py")
recursive: bool = Field(default=False, description="是否递归列出子目录")
max_files: int = Field(default=100, description="最大文件数")
class ListFilesTool(AgentTool):
"""
列出文件工具
列出目录中的文件
"""
EXCLUDE_DIRS = {
"node_modules", "vendor", "dist", "build", ".git",
"__pycache__", ".pytest_cache", "coverage",
}
def __init__(self, project_root: str):
super().__init__()
self.project_root = project_root
@property
def name(self) -> str:
return "list_files"
@property
def description(self) -> str:
return """列出目录中的文件。
使用场景:
- 了解项目结构
- 查找特定类型的文件
- 浏览目录内容
输入:
- directory: 目录路径
- pattern: 可选文件名模式
- recursive: 是否递归
- max_files: 最大文件数"""
@property
def args_schema(self):
return ListFilesInput
async def _execute(
self,
directory: str = ".",
pattern: Optional[str] = None,
recursive: bool = False,
max_files: int = 100,
**kwargs
) -> ToolResult:
"""执行文件列表"""
try:
target_dir = os.path.normpath(os.path.join(self.project_root, directory))
if not target_dir.startswith(os.path.normpath(self.project_root)):
return ToolResult(
success=False,
error="安全错误:不允许访问项目目录外的目录",
)
if not os.path.exists(target_dir):
return ToolResult(
success=False,
error=f"目录不存在: {directory}",
)
files = []
dirs = []
if recursive:
for root, dirnames, filenames in os.walk(target_dir):
# 排除目录
dirnames[:] = [d for d in dirnames if d not in self.EXCLUDE_DIRS]
for filename in filenames:
if pattern and not fnmatch.fnmatch(filename, pattern):
continue
full_path = os.path.join(root, filename)
relative_path = os.path.relpath(full_path, self.project_root)
files.append(relative_path)
if len(files) >= max_files:
break
if len(files) >= max_files:
break
else:
for item in os.listdir(target_dir):
if item in self.EXCLUDE_DIRS:
continue
full_path = os.path.join(target_dir, item)
relative_path = os.path.relpath(full_path, self.project_root)
if os.path.isdir(full_path):
dirs.append(relative_path + "/")
else:
if pattern and not fnmatch.fnmatch(item, pattern):
continue
files.append(relative_path)
if len(files) >= max_files:
break
# 格式化输出
output_parts = [f"📁 目录: {directory}\n"]
if dirs:
output_parts.append("目录:")
for d in sorted(dirs)[:20]:
output_parts.append(f" 📂 {d}")
if len(dirs) > 20:
output_parts.append(f" ... 还有 {len(dirs) - 20} 个目录")
if files:
output_parts.append(f"\n文件 ({len(files)}):")
for f in sorted(files):
output_parts.append(f" 📄 {f}")
if len(files) >= max_files:
output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
"directory": directory,
"file_count": len(files),
"dir_count": len(dirs),
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"列出文件失败: {str(e)}",
)