""" 文件操作工具 读取和搜索代码文件 """ import os import re import fnmatch import asyncio from typing import Optional, List, Dict, Any from pydantic import BaseModel, Field from .base import AgentTool, ToolResult from app.core.file_filter import TEXT_EXTENSIONS, EXCLUDE_DIRS, EXCLUDE_FILES, is_text_file, should_exclude class FileReadInput(BaseModel): """文件读取输入""" file_path: str = Field(description="文件路径(相对于项目根目录)") start_line: Optional[int] = Field(default=None, description="起始行号(从1开始)") end_line: Optional[int] = Field(default=None, description="结束行号") max_lines: int = Field(default=500, description="最大返回行数") class FileReadTool(AgentTool): """ 文件读取工具 读取项目中的文件内容 """ def __init__( self, project_root: str, exclude_patterns: Optional[List[str]] = None, target_files: Optional[List[str]] = None, ): """ 初始化文件读取工具 Args: project_root: 项目根目录 exclude_patterns: 排除模式列表 target_files: 目标文件列表(如果指定,只允许读取这些文件) """ super().__init__() self.project_root = project_root self.exclude_patterns = exclude_patterns or [] self.target_files = set(target_files) if target_files else None @staticmethod def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple: """同步读取文件指定行范围(用于 asyncio.to_thread)""" selected_lines = [] total_lines = 0 file_size = os.path.getsize(file_path) with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: for i, line in enumerate(f): total_lines = i + 1 if i >= start_idx and i < end_idx: selected_lines.append(line) elif i >= end_idx: if i < end_idx + 1000: continue else: remaining_bytes = file_size - f.tell() avg_line_size = f.tell() / (i + 1) estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0 total_lines = i + 1 + estimated_remaining_lines break return selected_lines, total_lines @staticmethod def _read_all_lines_sync(file_path: str) -> List[str]: """同步读取文件所有行(用于 asyncio.to_thread)""" with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: return f.readlines() @property def name(self) -> str: return "read_file" @property def description(self) -> str: return """读取项目中的文件内容。 使用场景: - 查看完整的源代码文件 - 查看特定行范围的代码 - 获取配置文件内容 输入: - file_path: 文件路径(相对于项目根目录) - start_line: 可选,起始行号 - end_line: 可选,结束行号 - max_lines: 最大返回行数(默认500) 注意: 为避免输出过长,建议指定行范围或使用 RAG 搜索定位代码。""" @property def args_schema(self): return FileReadInput def _should_exclude(self, file_path: str) -> bool: """检查文件是否应该被排除""" # 如果指定了目标文件,只允许读取这些文件 if self.target_files and file_path not in self.target_files: return True # 使用通用的筛选逻辑 filename = os.path.basename(file_path) return should_exclude(file_path, filename, self.exclude_patterns) async def _execute( self, file_path: str, start_line: Optional[int] = None, end_line: Optional[int] = None, max_lines: int = 500, **kwargs ) -> ToolResult: """执行文件读取""" try: # 检查是否被排除 if self._should_exclude(file_path): return ToolResult( success=False, error=f"文件被排除或不在目标文件列表中: {file_path}", ) # 安全检查:防止路径遍历 full_path = os.path.normpath(os.path.join(self.project_root, file_path)) if not full_path.startswith(os.path.normpath(self.project_root)): return ToolResult( success=False, error="安全错误:不允许访问项目目录外的文件", ) if not os.path.exists(full_path): return ToolResult( success=False, error=f"文件不存在: {file_path}", ) if not os.path.isfile(full_path): return ToolResult( success=False, error=f"不是文件: {file_path}", ) # 检查文件大小 file_size = os.path.getsize(full_path) is_large_file = file_size > 1024 * 1024 # 1MB # 🔥 修复:如果指定了行范围,允许读取大文件的部分内容 if is_large_file and start_line is None and end_line is None: return ToolResult( success=False, error=f"文件过大 ({file_size / 1024:.1f}KB),请指定 start_line 和 end_line 读取部分内容", ) # 🔥 对于大文件,使用流式读取指定行范围 if is_large_file and (start_line is not None or end_line is not None): # 计算实际的起始和结束行 start_idx = max(0, (start_line or 1) - 1) end_idx = end_line if end_line else start_idx + max_lines # 异步读取文件,避免阻塞事件循环 selected_lines, total_lines = await asyncio.to_thread( self._read_file_lines_sync, full_path, start_idx, end_idx ) # 更新实际的结束索引 end_idx = min(end_idx, start_idx + len(selected_lines)) else: # 异步读取小文件,避免阻塞事件循环 lines = await asyncio.to_thread(self._read_all_lines_sync, full_path) total_lines = len(lines) # 处理行范围 if start_line is not None: start_idx = max(0, start_line - 1) else: start_idx = 0 if end_line is not None: end_idx = min(total_lines, end_line) else: end_idx = min(total_lines, start_idx + max_lines) # 截取指定行 selected_lines = lines[start_idx:end_idx] # 添加行号 numbered_lines = [] for i, line in enumerate(selected_lines, start=start_idx + 1): numbered_lines.append(f"{i:4d}| {line.rstrip()}") content = '\n'.join(numbered_lines) # 检测语言 ext = os.path.splitext(file_path)[1].lower() language = { ".py": "python", ".js": "javascript", ".ts": "typescript", ".java": "java", ".go": "go", ".rs": "rust", ".cpp": "cpp", ".c": "c", ".cs": "csharp", ".php": "php", ".rb": "ruby", ".swift": "swift", }.get(ext, "text") output = f"📄 文件: {file_path}\n" output += f"行数: {start_idx + 1}-{end_idx} / {total_lines}\n\n" output += f"```{language}\n{content}\n```" if end_idx < total_lines: output += f"\n\n... 还有 {total_lines - end_idx} 行未显示" return ToolResult( success=True, data=output, metadata={ "file_path": file_path, "total_lines": total_lines, "start_line": start_idx + 1, "end_line": end_idx, "language": language, } ) except Exception as e: return ToolResult( success=False, error=f"读取文件失败: {str(e)}", ) class FileSearchInput(BaseModel): """文件搜索输入""" keyword: str = Field(description="搜索关键字或正则表达式") file_pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py, *.js") directory: Optional[str] = Field(default=None, description="搜索目录(相对路径)") case_sensitive: bool = Field(default=False, description="是否区分大小写") max_results: int = Field(default=50, description="最大结果数") is_regex: bool = Field(default=False, description="是否使用正则表达式") class FileSearchTool(AgentTool): """ 文件搜索工具 在项目中搜索包含特定内容的代码 """ # 排除的目录 DEFAULT_EXCLUDE_DIRS = { "node_modules", "vendor", "dist", "build", ".git", "__pycache__", ".pytest_cache", "coverage", ".nyc_output", ".vscode", ".idea", ".vs", "target", "venv", "env", } def __init__( self, project_root: str, exclude_patterns: Optional[List[str]] = None, target_files: Optional[List[str]] = None, ): super().__init__() self.project_root = project_root self.exclude_patterns = exclude_patterns or [] self.target_files = set(target_files) if target_files else None # 从 exclude_patterns 中提取目录排除 self.exclude_dirs = set(EXCLUDE_DIRS) for pattern in self.exclude_patterns: if pattern.endswith("/**"): self.exclude_dirs.add(pattern[:-3]) elif "/" not in pattern and "*" not in pattern: self.exclude_dirs.add(pattern) @staticmethod def _read_file_lines_sync(file_path: str) -> List[str]: """同步读取文件所有行(用于 asyncio.to_thread)""" with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: return f.readlines() @property def name(self) -> str: return "search_code" @property def description(self) -> str: return """在项目代码中搜索关键字或模式。 使用场景: - 查找特定函数的所有调用位置 - 搜索特定的 API 使用 - 查找包含特定模式的代码 输入: - keyword: 搜索关键字或正则表达式 - file_pattern: 可选,文件名模式(如 *.py) - directory: 可选,搜索目录 - case_sensitive: 是否区分大小写 - is_regex: 是否使用正则表达式 这是一个快速搜索工具,结果包含匹配行和上下文。""" @property def args_schema(self): return FileSearchInput async def _execute( self, keyword: str, file_pattern: Optional[str] = None, directory: Optional[str] = None, case_sensitive: bool = False, max_results: int = 50, is_regex: bool = False, **kwargs ) -> ToolResult: """执行文件搜索""" try: # 确定搜索目录 if directory: search_dir = os.path.normpath(os.path.join(self.project_root, directory)) if not search_dir.startswith(os.path.normpath(self.project_root)): return ToolResult( success=False, error="安全错误:不允许搜索项目目录外的内容", ) else: search_dir = self.project_root # 编译搜索模式 flags = 0 if case_sensitive else re.IGNORECASE try: if is_regex: pattern = re.compile(keyword, flags) else: pattern = re.compile(re.escape(keyword), flags) except re.error as e: return ToolResult( success=False, error=f"无效的搜索模式: {e}", ) results = [] files_searched = 0 # 遍历文件 for root, dirs, files in os.walk(search_dir): # 排除目录 dirs[:] = [d for d in dirs if d not in self.exclude_dirs] for filename in files: # 检查文件模式 if file_pattern and not fnmatch.fnmatch(filename, file_pattern): continue file_path = os.path.join(root, filename) relative_path = os.path.relpath(file_path, self.project_root) # 检查是否在目标文件列表中 if self.target_files and relative_path not in self.target_files: continue # 使用通用的筛选逻辑 if not is_text_file(filename) or should_exclude(relative_path, filename, self.exclude_patterns): continue try: # 异步读取文件,避免阻塞事件循环 lines = await asyncio.to_thread( self._read_file_lines_sync, file_path ) files_searched += 1 for i, line in enumerate(lines): if pattern.search(line): # 获取上下文 start = max(0, i - 1) end = min(len(lines), i + 2) context_lines = [] for j in range(start, end): prefix = ">" if j == i else " " context_lines.append(f"{prefix} {j+1:4d}| {lines[j].rstrip()}") results.append({ "file": relative_path, "line": i + 1, "match": line.strip()[:200], "context": '\n'.join(context_lines), }) if len(results) >= max_results: break if len(results) >= max_results: break except Exception: continue if len(results) >= max_results: break if not results: return ToolResult( success=True, data=f"没有找到匹配 '{keyword}' 的内容\n搜索了 {files_searched} 个文件", metadata={"files_searched": files_searched, "matches": 0} ) # 格式化输出 output_parts = [f"🔍 搜索结果: '{keyword}'\n"] output_parts.append(f"找到 {len(results)} 处匹配(搜索了 {files_searched} 个文件)\n") for result in results: output_parts.append(f"\n📄 {result['file']}:{result['line']}") output_parts.append(f"```\n{result['context']}\n```") if len(results) >= max_results: output_parts.append(f"\n... 结果已截断(最大 {max_results} 条)") return ToolResult( success=True, data="\n".join(output_parts), metadata={ "keyword": keyword, "files_searched": files_searched, "matches": len(results), "results": results[:10], # 只在元数据中保留前10个 } ) except Exception as e: return ToolResult( success=False, error=f"搜索失败: {str(e)}", ) class ListFilesInput(BaseModel): """列出文件输入""" directory: str = Field(default=".", description="目录路径(相对于项目根目录)") pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py") recursive: bool = Field(default=False, description="是否递归列出子目录") max_files: int = Field(default=100, description="最大文件数") class ListFilesTool(AgentTool): """ 列出文件工具 列出目录中的文件 """ def __init__( self, project_root: str, exclude_patterns: Optional[List[str]] = None, target_files: Optional[List[str]] = None, ): super().__init__() self.project_root = project_root self.exclude_patterns = exclude_patterns or [] self.target_files = set(target_files) if target_files else None # 使用通用的排除目录 self.exclude_dirs = set(EXCLUDE_DIRS) for pattern in self.exclude_patterns: # 如果是目录模式(如 node_modules/**),提取目录名 if pattern.endswith("/**"): self.exclude_dirs.add(pattern[:-3]) elif "/" not in pattern and "*" not in pattern: self.exclude_dirs.add(pattern) @property def name(self) -> str: return "list_files" @property def description(self) -> str: return """列出目录中的文件。 使用场景: - 了解项目结构 - 查找特定类型的文件 - 浏览目录内容 输入: - directory: 目录路径 - pattern: 可选,文件名模式 - recursive: 是否递归 - max_files: 最大文件数""" @property def args_schema(self): return ListFilesInput async def _execute( self, directory: str = ".", pattern: Optional[str] = None, recursive: bool = False, max_files: int = 100, **kwargs ) -> ToolResult: """执行文件列表""" try: # 🔥 兼容性处理:支持 path 参数作为 directory 的别名 if "path" in kwargs and kwargs["path"]: directory = kwargs["path"] target_dir = os.path.normpath(os.path.join(self.project_root, directory)) if not target_dir.startswith(os.path.normpath(self.project_root)): return ToolResult( success=False, error="安全错误:不允许访问项目目录外的目录", ) if not os.path.exists(target_dir): return ToolResult( success=False, error=f"目录不存在: {directory}", ) files = [] dirs = [] if recursive: for root, dirnames, filenames in os.walk(target_dir): # 排除目录 dirnames[:] = [d for d in dirnames if d not in self.exclude_dirs] for filename in filenames: if pattern and not fnmatch.fnmatch(filename, pattern): continue full_path = os.path.join(root, filename) relative_path = os.path.relpath(full_path, self.project_root) # 检查是否在目标文件列表中 if self.target_files and relative_path not in self.target_files: continue # 使用通用的筛选逻辑 if should_exclude(relative_path, filename, self.exclude_patterns): continue files.append(relative_path) if len(files) >= max_files: break if len(files) >= max_files: break else: # 🔥 如果设置了 target_files,只显示目标文件和包含目标文件的目录 if self.target_files: # 计算哪些目录包含目标文件 dirs_with_targets = set() for tf in self.target_files: # 获取目标文件的目录部分 tf_dir = os.path.dirname(tf) while tf_dir: dirs_with_targets.add(tf_dir) tf_dir = os.path.dirname(tf_dir) for item in os.listdir(target_dir): if item in self.exclude_dirs: continue full_path = os.path.join(target_dir, item) relative_path = os.path.relpath(full_path, self.project_root) if os.path.isdir(full_path): # 只显示包含目标文件的目录 if relative_path in dirs_with_targets or any( tf.startswith(relative_path + "/") for tf in self.target_files ): dirs.append(relative_path + "/") else: if pattern and not fnmatch.fnmatch(item, pattern): continue # 检查是否在目标文件列表中 if relative_path not in self.target_files: continue files.append(relative_path) if len(files) >= max_files: break else: # 没有设置 target_files,正常列出 for item in os.listdir(target_dir): if item in self.exclude_dirs: continue full_path = os.path.join(target_dir, item) relative_path = os.path.relpath(full_path, self.project_root) if os.path.isdir(full_path): dirs.append(relative_path + "/") else: if pattern and not fnmatch.fnmatch(item, pattern): continue # 使用通用的筛选逻辑 if should_exclude(relative_path, item, self.exclude_patterns): continue files.append(relative_path) if len(files) >= max_files: break # 格式化输出 output_parts = [f"📁 目录: {directory}\n"] # 🔥 如果设置了 target_files,显示提示信息 if self.target_files: output_parts.append(f"⚠️ 注意: 审计范围限定为 {len(self.target_files)} 个指定文件\n") if dirs: output_parts.append("目录:") for d in sorted(dirs)[:20]: output_parts.append(f" 📂 {d}") if len(dirs) > 20: output_parts.append(f" ... 还有 {len(dirs) - 20} 个目录") if files: output_parts.append(f"\n文件 ({len(files)}):") for f in sorted(files): output_parts.append(f" 📄 {f}") elif self.target_files: # 如果没有文件但设置了 target_files,显示目标文件列表 output_parts.append(f"\n指定的目标文件 ({len(self.target_files)}):") for f in sorted(self.target_files)[:20]: output_parts.append(f" 📄 {f}") if len(self.target_files) > 20: output_parts.append(f" ... 还有 {len(self.target_files) - 20} 个文件") if len(files) >= max_files: output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)") return ToolResult( success=True, data="\n".join(output_parts), metadata={ "directory": directory, "file_count": len(files), "dir_count": len(dirs), } ) except Exception as e: return ToolResult( success=False, error=f"列出文件失败: {str(e)}", )