CodeReview/backend/app/services/agent/tools/file_tool.py

631 lines
24 KiB
Python
Raw Normal View History

"""
文件操作工具
读取和搜索代码文件
"""
import os
import re
import fnmatch
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
class FileReadInput(BaseModel):
"""文件读取输入"""
file_path: str = Field(description="文件路径(相对于项目根目录)")
start_line: Optional[int] = Field(default=None, description="起始行号从1开始")
end_line: Optional[int] = Field(default=None, description="结束行号")
max_lines: int = Field(default=500, description="最大返回行数")
class FileReadTool(AgentTool):
"""
文件读取工具
读取项目中的文件内容
"""
def __init__(
self,
project_root: str,
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
):
"""
初始化文件读取工具
Args:
project_root: 项目根目录
exclude_patterns: 排除模式列表
target_files: 目标文件列表如果指定只允许读取这些文件
"""
super().__init__()
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
@property
def name(self) -> str:
return "read_file"
@property
def description(self) -> str:
return """读取项目中的文件内容。
使用场景:
- 查看完整的源代码文件
- 查看特定行范围的代码
- 获取配置文件内容
输入:
- file_path: 文件路径相对于项目根目录
- start_line: 可选起始行号
- end_line: 可选结束行号
- max_lines: 最大返回行数默认500
注意: 为避免输出过长建议指定行范围或使用 RAG 搜索定位代码"""
@property
def args_schema(self):
return FileReadInput
def _should_exclude(self, file_path: str) -> bool:
"""检查文件是否应该被排除"""
# 如果指定了目标文件,只允许读取这些文件
if self.target_files and file_path not in self.target_files:
return True
# 检查排除模式
for pattern in self.exclude_patterns:
if fnmatch.fnmatch(file_path, pattern):
return True
# 也检查文件名
if fnmatch.fnmatch(os.path.basename(file_path), pattern):
return True
return False
async def _execute(
self,
file_path: str,
start_line: Optional[int] = None,
end_line: Optional[int] = None,
max_lines: int = 500,
**kwargs
) -> ToolResult:
"""执行文件读取"""
try:
# 检查是否被排除
if self._should_exclude(file_path):
return ToolResult(
success=False,
error=f"文件被排除或不在目标文件列表中: {file_path}",
)
# 安全检查:防止路径遍历
full_path = os.path.normpath(os.path.join(self.project_root, file_path))
if not full_path.startswith(os.path.normpath(self.project_root)):
return ToolResult(
success=False,
error="安全错误:不允许访问项目目录外的文件",
)
if not os.path.exists(full_path):
return ToolResult(
success=False,
error=f"文件不存在: {file_path}",
)
if not os.path.isfile(full_path):
return ToolResult(
success=False,
error=f"不是文件: {file_path}",
)
# 检查文件大小
file_size = os.path.getsize(full_path)
if file_size > 1024 * 1024: # 1MB
return ToolResult(
success=False,
error=f"文件过大 ({file_size / 1024:.1f}KB),请指定行范围",
)
# 读取文件
with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
total_lines = len(lines)
# 处理行范围
if start_line is not None:
start_idx = max(0, start_line - 1)
else:
start_idx = 0
if end_line is not None:
end_idx = min(total_lines, end_line)
else:
end_idx = min(total_lines, start_idx + max_lines)
# 截取指定行
selected_lines = lines[start_idx:end_idx]
# 添加行号
numbered_lines = []
for i, line in enumerate(selected_lines, start=start_idx + 1):
numbered_lines.append(f"{i:4d}| {line.rstrip()}")
content = '\n'.join(numbered_lines)
# 检测语言
ext = os.path.splitext(file_path)[1].lower()
language = {
".py": "python", ".js": "javascript", ".ts": "typescript",
".java": "java", ".go": "go", ".rs": "rust",
".cpp": "cpp", ".c": "c", ".cs": "csharp",
".php": "php", ".rb": "ruby", ".swift": "swift",
}.get(ext, "text")
output = f"📄 文件: {file_path}\n"
output += f"行数: {start_idx + 1}-{end_idx} / {total_lines}\n\n"
output += f"```{language}\n{content}\n```"
if end_idx < total_lines:
output += f"\n\n... 还有 {total_lines - end_idx} 行未显示"
return ToolResult(
success=True,
data=output,
metadata={
"file_path": file_path,
"total_lines": total_lines,
"start_line": start_idx + 1,
"end_line": end_idx,
"language": language,
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"读取文件失败: {str(e)}",
)
class FileSearchInput(BaseModel):
"""文件搜索输入"""
keyword: str = Field(description="搜索关键字或正则表达式")
file_pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py, *.js")
directory: Optional[str] = Field(default=None, description="搜索目录(相对路径)")
case_sensitive: bool = Field(default=False, description="是否区分大小写")
max_results: int = Field(default=50, description="最大结果数")
is_regex: bool = Field(default=False, description="是否使用正则表达式")
class FileSearchTool(AgentTool):
"""
文件搜索工具
在项目中搜索包含特定内容的代码
"""
# 排除的目录
DEFAULT_EXCLUDE_DIRS = {
"node_modules", "vendor", "dist", "build", ".git",
"__pycache__", ".pytest_cache", "coverage", ".nyc_output",
".vscode", ".idea", ".vs", "target", "venv", "env",
}
def __init__(
self,
project_root: str,
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
):
super().__init__()
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
# 从 exclude_patterns 中提取目录排除
self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
for pattern in self.exclude_patterns:
if pattern.endswith("/**"):
self.exclude_dirs.add(pattern[:-3])
elif "/" not in pattern and "*" not in pattern:
self.exclude_dirs.add(pattern)
@property
def name(self) -> str:
return "search_code"
@property
def description(self) -> str:
return """在项目代码中搜索关键字或模式。
使用场景:
- 查找特定函数的所有调用位置
- 搜索特定的 API 使用
- 查找包含特定模式的代码
输入:
- keyword: 搜索关键字或正则表达式
- file_pattern: 可选文件名模式 *.py
- directory: 可选搜索目录
- case_sensitive: 是否区分大小写
- is_regex: 是否使用正则表达式
这是一个快速搜索工具结果包含匹配行和上下文"""
@property
def args_schema(self):
return FileSearchInput
async def _execute(
self,
keyword: str,
file_pattern: Optional[str] = None,
directory: Optional[str] = None,
case_sensitive: bool = False,
max_results: int = 50,
is_regex: bool = False,
**kwargs
) -> ToolResult:
"""执行文件搜索"""
try:
# 确定搜索目录
if directory:
search_dir = os.path.normpath(os.path.join(self.project_root, directory))
if not search_dir.startswith(os.path.normpath(self.project_root)):
return ToolResult(
success=False,
error="安全错误:不允许搜索项目目录外的内容",
)
else:
search_dir = self.project_root
# 编译搜索模式
flags = 0 if case_sensitive else re.IGNORECASE
try:
if is_regex:
pattern = re.compile(keyword, flags)
else:
pattern = re.compile(re.escape(keyword), flags)
except re.error as e:
return ToolResult(
success=False,
error=f"无效的搜索模式: {e}",
)
results = []
files_searched = 0
# 遍历文件
for root, dirs, files in os.walk(search_dir):
# 排除目录
dirs[:] = [d for d in dirs if d not in self.exclude_dirs]
for filename in files:
# 检查文件模式
if file_pattern and not fnmatch.fnmatch(filename, file_pattern):
continue
file_path = os.path.join(root, filename)
relative_path = os.path.relpath(file_path, self.project_root)
# 检查是否在目标文件列表中
if self.target_files and relative_path not in self.target_files:
continue
# 检查排除模式
should_skip = False
for excl_pattern in self.exclude_patterns:
if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(filename, excl_pattern):
should_skip = True
break
if should_skip:
continue
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
files_searched += 1
for i, line in enumerate(lines):
if pattern.search(line):
# 获取上下文
start = max(0, i - 1)
end = min(len(lines), i + 2)
context_lines = []
for j in range(start, end):
prefix = ">" if j == i else " "
context_lines.append(f"{prefix} {j+1:4d}| {lines[j].rstrip()}")
results.append({
"file": relative_path,
"line": i + 1,
"match": line.strip()[:200],
"context": '\n'.join(context_lines),
})
if len(results) >= max_results:
break
if len(results) >= max_results:
break
except Exception:
continue
if len(results) >= max_results:
break
if not results:
return ToolResult(
success=True,
data=f"没有找到匹配 '{keyword}' 的内容\n搜索了 {files_searched} 个文件",
metadata={"files_searched": files_searched, "matches": 0}
)
# 格式化输出
output_parts = [f"🔍 搜索结果: '{keyword}'\n"]
output_parts.append(f"找到 {len(results)} 处匹配(搜索了 {files_searched} 个文件)\n")
for result in results:
output_parts.append(f"\n📄 {result['file']}:{result['line']}")
output_parts.append(f"```\n{result['context']}\n```")
if len(results) >= max_results:
output_parts.append(f"\n... 结果已截断(最大 {max_results} 条)")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
"keyword": keyword,
"files_searched": files_searched,
"matches": len(results),
"results": results[:10], # 只在元数据中保留前10个
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"搜索失败: {str(e)}",
)
class ListFilesInput(BaseModel):
"""列出文件输入"""
directory: str = Field(default=".", description="目录路径(相对于项目根目录)")
pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py")
recursive: bool = Field(default=False, description="是否递归列出子目录")
max_files: int = Field(default=100, description="最大文件数")
class ListFilesTool(AgentTool):
"""
列出文件工具
列出目录中的文件
"""
DEFAULT_EXCLUDE_DIRS = {
"node_modules", "vendor", "dist", "build", ".git",
"__pycache__", ".pytest_cache", "coverage",
}
def __init__(
self,
project_root: str,
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
):
super().__init__()
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
# 从 exclude_patterns 中提取目录排除
self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
for pattern in self.exclude_patterns:
# 如果是目录模式(如 node_modules/**),提取目录名
if pattern.endswith("/**"):
self.exclude_dirs.add(pattern[:-3])
elif "/" not in pattern and "*" not in pattern:
self.exclude_dirs.add(pattern)
@property
def name(self) -> str:
return "list_files"
@property
def description(self) -> str:
return """列出目录中的文件。
使用场景:
- 了解项目结构
- 查找特定类型的文件
- 浏览目录内容
输入:
- directory: 目录路径
- pattern: 可选文件名模式
- recursive: 是否递归
- max_files: 最大文件数"""
@property
def args_schema(self):
return ListFilesInput
async def _execute(
self,
directory: str = ".",
pattern: Optional[str] = None,
recursive: bool = False,
max_files: int = 100,
**kwargs
) -> ToolResult:
"""执行文件列表"""
try:
target_dir = os.path.normpath(os.path.join(self.project_root, directory))
if not target_dir.startswith(os.path.normpath(self.project_root)):
return ToolResult(
success=False,
error="安全错误:不允许访问项目目录外的目录",
)
if not os.path.exists(target_dir):
return ToolResult(
success=False,
error=f"目录不存在: {directory}",
)
files = []
dirs = []
if recursive:
for root, dirnames, filenames in os.walk(target_dir):
# 排除目录
dirnames[:] = [d for d in dirnames if d not in self.exclude_dirs]
for filename in filenames:
if pattern and not fnmatch.fnmatch(filename, pattern):
continue
full_path = os.path.join(root, filename)
relative_path = os.path.relpath(full_path, self.project_root)
# 检查是否在目标文件列表中
if self.target_files and relative_path not in self.target_files:
continue
# 检查排除模式
should_skip = False
for excl_pattern in self.exclude_patterns:
if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(filename, excl_pattern):
should_skip = True
break
if should_skip:
continue
files.append(relative_path)
if len(files) >= max_files:
break
if len(files) >= max_files:
break
else:
# 🔥 如果设置了 target_files只显示目标文件和包含目标文件的目录
if self.target_files:
# 计算哪些目录包含目标文件
dirs_with_targets = set()
for tf in self.target_files:
# 获取目标文件的目录部分
tf_dir = os.path.dirname(tf)
while tf_dir:
dirs_with_targets.add(tf_dir)
tf_dir = os.path.dirname(tf_dir)
for item in os.listdir(target_dir):
if item in self.exclude_dirs:
continue
full_path = os.path.join(target_dir, item)
relative_path = os.path.relpath(full_path, self.project_root)
if os.path.isdir(full_path):
# 只显示包含目标文件的目录
if relative_path in dirs_with_targets or any(
tf.startswith(relative_path + "/") for tf in self.target_files
):
dirs.append(relative_path + "/")
else:
if pattern and not fnmatch.fnmatch(item, pattern):
continue
# 检查是否在目标文件列表中
if relative_path not in self.target_files:
continue
files.append(relative_path)
if len(files) >= max_files:
break
else:
# 没有设置 target_files正常列出
for item in os.listdir(target_dir):
if item in self.exclude_dirs:
continue
full_path = os.path.join(target_dir, item)
relative_path = os.path.relpath(full_path, self.project_root)
if os.path.isdir(full_path):
dirs.append(relative_path + "/")
else:
if pattern and not fnmatch.fnmatch(item, pattern):
continue
# 检查排除模式
should_skip = False
for excl_pattern in self.exclude_patterns:
if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(item, excl_pattern):
should_skip = True
break
if should_skip:
continue
files.append(relative_path)
if len(files) >= max_files:
break
# 格式化输出
output_parts = [f"📁 目录: {directory}\n"]
# 🔥 如果设置了 target_files显示提示信息
if self.target_files:
output_parts.append(f"⚠️ 注意: 审计范围限定为 {len(self.target_files)} 个指定文件\n")
if dirs:
output_parts.append("目录:")
for d in sorted(dirs)[:20]:
output_parts.append(f" 📂 {d}")
if len(dirs) > 20:
output_parts.append(f" ... 还有 {len(dirs) - 20} 个目录")
if files:
output_parts.append(f"\n文件 ({len(files)}):")
for f in sorted(files):
output_parts.append(f" 📄 {f}")
elif self.target_files:
# 如果没有文件但设置了 target_files显示目标文件列表
output_parts.append(f"\n指定的目标文件 ({len(self.target_files)}):")
for f in sorted(self.target_files)[:20]:
output_parts.append(f" 📄 {f}")
if len(self.target_files) > 20:
output_parts.append(f" ... 还有 {len(self.target_files) - 20} 个文件")
if len(files) >= max_files:
output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
"directory": directory,
"file_count": len(files),
"dir_count": len(dirs),
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"列出文件失败: {str(e)}",
)