CodeReview/backend/app/services/agent/tools/file_tool.py

665 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
文件操作工具
读取和搜索代码文件
"""
import os
import re
import fnmatch
import asyncio
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
from app.core.file_filter import TEXT_EXTENSIONS, EXCLUDE_DIRS, EXCLUDE_FILES, is_text_file, should_exclude
class FileReadInput(BaseModel):
"""文件读取输入"""
file_path: str = Field(description="文件路径(相对于项目根目录)")
start_line: Optional[int] = Field(default=None, description="起始行号从1开始")
end_line: Optional[int] = Field(default=None, description="结束行号")
max_lines: int = Field(default=500, description="最大返回行数")
class FileReadTool(AgentTool):
"""
文件读取工具
读取项目中的文件内容
"""
def __init__(
self,
project_root: str,
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
):
"""
初始化文件读取工具
Args:
project_root: 项目根目录
exclude_patterns: 排除模式列表
target_files: 目标文件列表(如果指定,只允许读取这些文件)
"""
super().__init__()
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
@staticmethod
def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
"""同步读取文件指定行范围(用于 asyncio.to_thread"""
selected_lines = []
total_lines = 0
file_size = os.path.getsize(file_path)
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
for i, line in enumerate(f):
total_lines = i + 1
if i >= start_idx and i < end_idx:
selected_lines.append(line)
elif i >= end_idx:
if i < end_idx + 1000:
continue
else:
remaining_bytes = file_size - f.tell()
avg_line_size = f.tell() / (i + 1)
estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
total_lines = i + 1 + estimated_remaining_lines
break
return selected_lines, total_lines
@staticmethod
def _read_all_lines_sync(file_path: str) -> List[str]:
"""同步读取文件所有行(用于 asyncio.to_thread"""
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.readlines()
@property
def name(self) -> str:
return "read_file"
@property
def description(self) -> str:
return """读取项目中的文件内容。
使用场景:
- 查看完整的源代码文件
- 查看特定行范围的代码
- 获取配置文件内容
输入:
- file_path: 文件路径(相对于项目根目录)
- start_line: 可选,起始行号
- end_line: 可选,结束行号
- max_lines: 最大返回行数默认500
注意: 为避免输出过长,建议指定行范围或使用 RAG 搜索定位代码。"""
@property
def args_schema(self):
return FileReadInput
def _should_exclude(self, file_path: str) -> bool:
"""检查文件是否应该被排除"""
# 如果指定了目标文件,只允许读取这些文件
if self.target_files and file_path not in self.target_files:
return True
# 使用通用的筛选逻辑
filename = os.path.basename(file_path)
return should_exclude(file_path, filename, self.exclude_patterns)
async def _execute(
self,
file_path: str,
start_line: Optional[int] = None,
end_line: Optional[int] = None,
max_lines: int = 500,
**kwargs
) -> ToolResult:
"""执行文件读取"""
try:
# 检查是否被排除
if self._should_exclude(file_path):
return ToolResult(
success=False,
error=f"文件被排除或不在目标文件列表中: {file_path}",
)
# 安全检查:防止路径遍历
full_path = os.path.normpath(os.path.join(self.project_root, file_path))
if not full_path.startswith(os.path.normpath(self.project_root)):
return ToolResult(
success=False,
error="安全错误:不允许访问项目目录外的文件",
)
if not os.path.exists(full_path):
return ToolResult(
success=False,
error=f"文件不存在: {file_path}",
)
if not os.path.isfile(full_path):
return ToolResult(
success=False,
error=f"不是文件: {file_path}",
)
# 检查文件大小
file_size = os.path.getsize(full_path)
is_large_file = file_size > 1024 * 1024 # 1MB
# 🔥 修复:如果指定了行范围,允许读取大文件的部分内容
if is_large_file and start_line is None and end_line is None:
return ToolResult(
success=False,
error=f"文件过大 ({file_size / 1024:.1f}KB),请指定 start_line 和 end_line 读取部分内容",
)
# 🔥 对于大文件,使用流式读取指定行范围
if is_large_file and (start_line is not None or end_line is not None):
# 计算实际的起始和结束行
start_idx = max(0, (start_line or 1) - 1)
end_idx = end_line if end_line else start_idx + max_lines
# 异步读取文件,避免阻塞事件循环
selected_lines, total_lines = await asyncio.to_thread(
self._read_file_lines_sync, full_path, start_idx, end_idx
)
# 更新实际的结束索引
end_idx = min(end_idx, start_idx + len(selected_lines))
else:
# 异步读取小文件,避免阻塞事件循环
lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)
total_lines = len(lines)
# 处理行范围
if start_line is not None:
start_idx = max(0, start_line - 1)
else:
start_idx = 0
if end_line is not None:
end_idx = min(total_lines, end_line)
else:
end_idx = min(total_lines, start_idx + max_lines)
# 截取指定行
selected_lines = lines[start_idx:end_idx]
# 添加行号
numbered_lines = []
for i, line in enumerate(selected_lines, start=start_idx + 1):
numbered_lines.append(f"{i:4d}| {line.rstrip()}")
content = '\n'.join(numbered_lines)
# 检测语言
ext = os.path.splitext(file_path)[1].lower()
language = {
".py": "python", ".js": "javascript", ".ts": "typescript",
".java": "java", ".go": "go", ".rs": "rust",
".cpp": "cpp", ".c": "c", ".cs": "csharp",
".php": "php", ".rb": "ruby", ".swift": "swift",
}.get(ext, "text")
output = f"📄 文件: {file_path}\n"
output += f"行数: {start_idx + 1}-{end_idx} / {total_lines}\n\n"
output += f"```{language}\n{content}\n```"
if end_idx < total_lines:
output += f"\n\n... 还有 {total_lines - end_idx} 行未显示"
return ToolResult(
success=True,
data=output,
metadata={
"file_path": file_path,
"total_lines": total_lines,
"start_line": start_idx + 1,
"end_line": end_idx,
"language": language,
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"读取文件失败: {str(e)}",
)
class FileSearchInput(BaseModel):
"""文件搜索输入"""
keyword: str = Field(description="搜索关键字或正则表达式")
file_pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py, *.js")
directory: Optional[str] = Field(default=None, description="搜索目录(相对路径)")
case_sensitive: bool = Field(default=False, description="是否区分大小写")
max_results: int = Field(default=50, description="最大结果数")
is_regex: bool = Field(default=False, description="是否使用正则表达式")
class FileSearchTool(AgentTool):
"""
文件搜索工具
在项目中搜索包含特定内容的代码
"""
# 排除的目录
DEFAULT_EXCLUDE_DIRS = {
"node_modules", "vendor", "dist", "build", ".git",
"__pycache__", ".pytest_cache", "coverage", ".nyc_output",
".vscode", ".idea", ".vs", "target", "venv", "env",
}
def __init__(
self,
project_root: str,
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
):
super().__init__()
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
# 从 exclude_patterns 中提取目录排除
self.exclude_dirs = set(EXCLUDE_DIRS)
for pattern in self.exclude_patterns:
if pattern.endswith("/**"):
self.exclude_dirs.add(pattern[:-3])
elif "/" not in pattern and "*" not in pattern:
self.exclude_dirs.add(pattern)
@staticmethod
def _read_file_lines_sync(file_path: str) -> List[str]:
"""同步读取文件所有行(用于 asyncio.to_thread"""
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.readlines()
@property
def name(self) -> str:
return "search_code"
@property
def description(self) -> str:
return """在项目代码中搜索关键字或模式。
使用场景:
- 查找特定函数的所有调用位置
- 搜索特定的 API 使用
- 查找包含特定模式的代码
输入:
- keyword: 搜索关键字或正则表达式
- file_pattern: 可选,文件名模式(如 *.py
- directory: 可选,搜索目录
- case_sensitive: 是否区分大小写
- is_regex: 是否使用正则表达式
这是一个快速搜索工具,结果包含匹配行和上下文。"""
@property
def args_schema(self):
return FileSearchInput
async def _execute(
self,
keyword: str,
file_pattern: Optional[str] = None,
directory: Optional[str] = None,
case_sensitive: bool = False,
max_results: int = 50,
is_regex: bool = False,
**kwargs
) -> ToolResult:
"""执行文件搜索"""
try:
# 确定搜索目录
if directory:
search_dir = os.path.normpath(os.path.join(self.project_root, directory))
if not search_dir.startswith(os.path.normpath(self.project_root)):
return ToolResult(
success=False,
error="安全错误:不允许搜索项目目录外的内容",
)
else:
search_dir = self.project_root
# 编译搜索模式
flags = 0 if case_sensitive else re.IGNORECASE
try:
if is_regex:
pattern = re.compile(keyword, flags)
else:
pattern = re.compile(re.escape(keyword), flags)
except re.error as e:
return ToolResult(
success=False,
error=f"无效的搜索模式: {e}",
)
results = []
files_searched = 0
# 遍历文件
for root, dirs, files in os.walk(search_dir):
# 排除目录
dirs[:] = [d for d in dirs if d not in self.exclude_dirs]
for filename in files:
# 检查文件模式
if file_pattern and not fnmatch.fnmatch(filename, file_pattern):
continue
file_path = os.path.join(root, filename)
relative_path = os.path.relpath(file_path, self.project_root)
# 检查是否在目标文件列表中
if self.target_files and relative_path not in self.target_files:
continue
# 使用通用的筛选逻辑
if not is_text_file(filename) or should_exclude(relative_path, filename, self.exclude_patterns):
continue
try:
# 异步读取文件,避免阻塞事件循环
lines = await asyncio.to_thread(
self._read_file_lines_sync, file_path
)
files_searched += 1
for i, line in enumerate(lines):
if pattern.search(line):
# 获取上下文
start = max(0, i - 1)
end = min(len(lines), i + 2)
context_lines = []
for j in range(start, end):
prefix = ">" if j == i else " "
context_lines.append(f"{prefix} {j+1:4d}| {lines[j].rstrip()}")
results.append({
"file": relative_path,
"line": i + 1,
"match": line.strip()[:200],
"context": '\n'.join(context_lines),
})
if len(results) >= max_results:
break
if len(results) >= max_results:
break
except Exception:
continue
if len(results) >= max_results:
break
if not results:
return ToolResult(
success=True,
data=f"没有找到匹配 '{keyword}' 的内容\n搜索了 {files_searched} 个文件",
metadata={"files_searched": files_searched, "matches": 0}
)
# 格式化输出
output_parts = [f"🔍 搜索结果: '{keyword}'\n"]
output_parts.append(f"找到 {len(results)} 处匹配(搜索了 {files_searched} 个文件)\n")
for result in results:
output_parts.append(f"\n📄 {result['file']}:{result['line']}")
output_parts.append(f"```\n{result['context']}\n```")
if len(results) >= max_results:
output_parts.append(f"\n... 结果已截断(最大 {max_results} 条)")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
"keyword": keyword,
"files_searched": files_searched,
"matches": len(results),
"results": results[:10], # 只在元数据中保留前10个
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"搜索失败: {str(e)}",
)
class ListFilesInput(BaseModel):
"""列出文件输入"""
directory: str = Field(default=".", description="目录路径(相对于项目根目录)")
pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py")
recursive: bool = Field(default=False, description="是否递归列出子目录")
max_files: int = Field(default=100, description="最大文件数")
class ListFilesTool(AgentTool):
"""
列出文件工具
列出目录中的文件
"""
def __init__(
self,
project_root: str,
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
):
super().__init__()
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
# 使用通用的排除目录
self.exclude_dirs = set(EXCLUDE_DIRS)
for pattern in self.exclude_patterns:
# 如果是目录模式(如 node_modules/**),提取目录名
if pattern.endswith("/**"):
self.exclude_dirs.add(pattern[:-3])
elif "/" not in pattern and "*" not in pattern:
self.exclude_dirs.add(pattern)
@property
def name(self) -> str:
return "list_files"
@property
def description(self) -> str:
return """列出目录中的文件。
使用场景:
- 了解项目结构
- 查找特定类型的文件
- 浏览目录内容
输入:
- directory: 目录路径
- pattern: 可选,文件名模式
- recursive: 是否递归
- max_files: 最大文件数"""
@property
def args_schema(self):
return ListFilesInput
async def _execute(
self,
directory: str = ".",
pattern: Optional[str] = None,
recursive: bool = False,
max_files: int = 100,
**kwargs
) -> ToolResult:
"""执行文件列表"""
try:
# 🔥 兼容性处理:支持 path 参数作为 directory 的别名
if "path" in kwargs and kwargs["path"]:
directory = kwargs["path"]
target_dir = os.path.normpath(os.path.join(self.project_root, directory))
if not target_dir.startswith(os.path.normpath(self.project_root)):
return ToolResult(
success=False,
error="安全错误:不允许访问项目目录外的目录",
)
if not os.path.exists(target_dir):
return ToolResult(
success=False,
error=f"目录不存在: {directory}",
)
files = []
dirs = []
if recursive:
for root, dirnames, filenames in os.walk(target_dir):
# 排除目录
dirnames[:] = [d for d in dirnames if d not in self.exclude_dirs]
for filename in filenames:
if pattern and not fnmatch.fnmatch(filename, pattern):
continue
full_path = os.path.join(root, filename)
relative_path = os.path.relpath(full_path, self.project_root)
# 检查是否在目标文件列表中
if self.target_files and relative_path not in self.target_files:
continue
# 使用通用的筛选逻辑
if should_exclude(relative_path, filename, self.exclude_patterns):
continue
files.append(relative_path)
if len(files) >= max_files:
break
if len(files) >= max_files:
break
else:
# 🔥 如果设置了 target_files只显示目标文件和包含目标文件的目录
if self.target_files:
# 计算哪些目录包含目标文件
dirs_with_targets = set()
for tf in self.target_files:
# 获取目标文件的目录部分
tf_dir = os.path.dirname(tf)
while tf_dir:
dirs_with_targets.add(tf_dir)
tf_dir = os.path.dirname(tf_dir)
for item in os.listdir(target_dir):
if item in self.exclude_dirs:
continue
full_path = os.path.join(target_dir, item)
relative_path = os.path.relpath(full_path, self.project_root)
if os.path.isdir(full_path):
# 只显示包含目标文件的目录
if relative_path in dirs_with_targets or any(
tf.startswith(relative_path + "/") for tf in self.target_files
):
dirs.append(relative_path + "/")
else:
if pattern and not fnmatch.fnmatch(item, pattern):
continue
# 检查是否在目标文件列表中
if relative_path not in self.target_files:
continue
files.append(relative_path)
if len(files) >= max_files:
break
else:
# 没有设置 target_files正常列出
for item in os.listdir(target_dir):
if item in self.exclude_dirs:
continue
full_path = os.path.join(target_dir, item)
relative_path = os.path.relpath(full_path, self.project_root)
if os.path.isdir(full_path):
dirs.append(relative_path + "/")
else:
if pattern and not fnmatch.fnmatch(item, pattern):
continue
# 使用通用的筛选逻辑
if should_exclude(relative_path, item, self.exclude_patterns):
continue
files.append(relative_path)
if len(files) >= max_files:
break
# 格式化输出
output_parts = [f"📁 目录: {directory}\n"]
# 🔥 如果设置了 target_files显示提示信息
if self.target_files:
output_parts.append(f"⚠️ 注意: 审计范围限定为 {len(self.target_files)} 个指定文件\n")
if dirs:
output_parts.append("目录:")
for d in sorted(dirs)[:20]:
output_parts.append(f" 📂 {d}")
if len(dirs) > 20:
output_parts.append(f" ... 还有 {len(dirs) - 20} 个目录")
if files:
output_parts.append(f"\n文件 ({len(files)}):")
for f in sorted(files):
output_parts.append(f" 📄 {f}")
elif self.target_files:
# 如果没有文件但设置了 target_files显示目标文件列表
output_parts.append(f"\n指定的目标文件 ({len(self.target_files)}):")
for f in sorted(self.target_files)[:20]:
output_parts.append(f" 📄 {f}")
if len(self.target_files) > 20:
output_parts.append(f" ... 还有 {len(self.target_files) - 20} 个文件")
if len(files) >= max_files:
output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)")
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata={
"directory": directory,
"file_count": len(files),
"dir_count": len(dirs),
}
)
except Exception as e:
return ToolResult(
success=False,
error=f"列出文件失败: {str(e)}",
)