"""
File operation tools: read, search and list code files in a project.

Every tool is confined to ``project_root``: requested paths are resolved
(including ``..`` segments and symlinks) and rejected when they land outside
that directory.
"""

import fnmatch
import os
import re
from typing import Optional, List, Dict, Any, Tuple

from pydantic import BaseModel, Field

from .base import AgentTool, ToolResult


# Files larger than this (in bytes) are only read when a line range is given.
_MAX_READ_FILE_SIZE = 1024 * 1024  # 1MB

# File extension -> language tag used for the markdown code fence of read_file.
_EXTENSION_LANGUAGES = {
    ".py": "python",
    ".js": "javascript",
    ".ts": "typescript",
    ".java": "java",
    ".go": "go",
    ".rs": "rust",
    ".cpp": "cpp",
    ".c": "c",
    ".cs": "csharp",
    ".php": "php",
    ".rb": "ruby",
    ".swift": "swift",
}


def _is_within(root: str, path: str) -> bool:
    """Return True if absolute *path* is *root* itself or lies beneath it.

    Compares whole path components, so ``/srv/app2`` is NOT inside
    ``/srv/app`` (a plain ``str.startswith`` check would wrongly accept it).
    """
    try:
        return os.path.commonpath([root, path]) == root
    except ValueError:
        # commonpath raises for paths that share no root (e.g. different
        # Windows drives); such a path is certainly not inside.
        return False


def _resolve_in_project(root: str, relative_path: str) -> Optional[str]:
    """Resolve *relative_path* against *root* (which must already be a realpath).

    ``..`` segments and symlinks are resolved before the containment check,
    so neither can be used to escape the project. An absolute
    *relative_path* replaces *root* in ``os.path.join`` and is then rejected
    unless it happens to point inside the project.

    Returns:
        The resolved absolute path, or None if it lies outside *root*.
    """
    candidate = os.path.realpath(os.path.join(root, relative_path))
    return candidate if _is_within(root, candidate) else None


# Input schema for FileReadTool. The docstring and Field descriptions are kept
# verbatim: pydantic copies them into the model's JSON schema.
class FileReadInput(BaseModel):
    """文件读取输入"""
    file_path: str = Field(description="文件路径(相对于项目根目录)")
    start_line: Optional[int] = Field(default=None, description="起始行号(从1开始)")
    end_line: Optional[int] = Field(default=None, description="结束行号")
    max_lines: int = Field(default=500, description="最大返回行数")


class FileReadTool(AgentTool):
    """
    File reading tool.

    Returns a line-numbered excerpt of a file inside the project.
    """

    def __init__(self, project_root: str):
        """
        Initialize the file reading tool.

        Args:
            project_root: Project root directory; every requested path is
                resolved relative to it and may not escape it.
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "read_file"

    @property
    def description(self) -> str:
        return """读取项目中的文件内容。 使用场景: - 查看完整的源代码文件 - 查看特定行范围的代码 - 获取配置文件内容 输入: - file_path: 文件路径(相对于项目根目录) - start_line: 可选,起始行号 - end_line: 可选,结束行号 - max_lines: 最大返回行数(默认500) 注意: 为避免输出过长,建议指定行范围或使用 RAG 搜索定位代码。"""

    @property
    def args_schema(self):
        return FileReadInput

    async def _execute(
        self,
        file_path: str,
        start_line: Optional[int] = None,
        end_line: Optional[int] = None,
        max_lines: int = 500,
        **kwargs
    ) -> ToolResult:
        """
        Read a file and return the requested lines with line numbers.

        Args:
            file_path: Path relative to the project root.
            start_line: First line to show (1-based). Defaults to 1.
            end_line: Last line to show (inclusive). Defaults to
                ``start_line + max_lines - 1``.
            max_lines: Upper bound on the number of lines returned; it also
                caps an explicit ``end_line``.

        Returns:
            ToolResult with a fenced, line-numbered excerpt in ``data`` and the
            file path, total line count, shown range and language tag in
            ``metadata``. Failures (path outside the project, missing file,
            invalid range, oversized file without a range, I/O errors) are
            reported as ``success=False`` instead of being raised.
        """
        try:
            # Security: resolve ``..`` and symlinks, then require containment.
            root = os.path.realpath(self.project_root)
            full_path = _resolve_in_project(root, file_path)
            if full_path is None:
                return ToolResult(
                    success=False,
                    error="安全错误:不允许访问项目目录外的文件",
                )

            if not os.path.exists(full_path):
                return ToolResult(
                    success=False,
                    error=f"文件不存在: {file_path}",
                )

            if not os.path.isfile(full_path):
                return ToolResult(
                    success=False,
                    error=f"不是文件: {file_path}",
                )

            if start_line is not None and end_line is not None and end_line < start_line:
                return ToolResult(
                    success=False,
                    error=f"结束行号 ({end_line}) 不能小于起始行号 ({start_line})",
                )

            # Oversized files are refused only when no range was requested;
            # with a range, the file is streamed and only the excerpt is kept.
            file_size = os.path.getsize(full_path)
            has_range = start_line is not None or end_line is not None
            if file_size > _MAX_READ_FILE_SIZE and not has_range:
                return ToolResult(
                    success=False,
                    error=f"文件过大 ({file_size / 1024:.1f}KB),请指定行范围",
                )

            # 0-based half-open window [start_idx, stop_idx).
            start_idx = max(0, start_line - 1) if start_line is not None else 0
            # max_lines bounds the excerpt even when end_line asks for more.
            stop_idx = start_idx + max(1, max_lines)
            if end_line is not None:
                stop_idx = min(stop_idx, end_line)
            # Guard against end_line <= 0: never let the window run backwards
            # (the old list slice treated a negative end as "from the end").
            stop_idx = max(start_idx, stop_idx)

            selected_lines, total_lines = self._read_window(full_path, start_idx, stop_idx)

            if start_idx > 0 and start_idx >= total_lines:
                return ToolResult(
                    success=False,
                    error=f"起始行号 {start_line} 超出文件范围(共 {total_lines} 行)",
                )

            end_idx = min(total_lines, stop_idx)

            # Prefix every line with its 1-based line number.
            content = '\n'.join(
                f"{number:4d}| {line.rstrip()}"
                for number, line in enumerate(selected_lines, start=start_idx + 1)
            )

            language = _EXTENSION_LANGUAGES.get(
                os.path.splitext(file_path)[1].lower(), "text"
            )

            output = f"📄 文件: {file_path}\n"
            output += f"行数: {start_idx + 1}-{end_idx} / {total_lines}\n\n"
            output += f"```{language}\n{content}\n```"

            if end_idx < total_lines:
                output += f"\n\n... 还有 {total_lines - end_idx} 行未显示"

            return ToolResult(
                success=True,
                data=output,
                metadata={
                    "file_path": file_path,
                    "total_lines": total_lines,
                    "start_line": start_idx + 1,
                    "end_line": end_idx,
                    "language": language,
                }
            )

        except Exception as e:
            # Tool boundary: report any unexpected failure to the caller.
            return ToolResult(
                success=False,
                error=f"读取文件失败: {str(e)}",
            )

    @staticmethod
    def _read_window(full_path: str, start_idx: int, stop_idx: int) -> Tuple[List[str], int]:
        """Stream *full_path*, keeping lines whose 0-based index is in [start_idx, stop_idx).

        Returns:
            (selected lines including their line endings, total line count).
            Only the window is held in memory, which is what makes ranged
            reads of files above the size limit possible.
        """
        selected: List[str] = []
        total = 0
        with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
            for total, line in enumerate(f, start=1):
                # ``total`` is the 1-based number of the current line.
                if start_idx < total <= stop_idx:
                    selected.append(line)
        return selected, total


# Input schema for FileSearchTool (docstring/descriptions feed the JSON schema).
class FileSearchInput(BaseModel):
    """文件搜索输入"""
    keyword: str = Field(description="搜索关键字或正则表达式")
    file_pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py, *.js")
    directory: Optional[str] = Field(default=None, description="搜索目录(相对路径)")
    case_sensitive: bool = Field(default=False, description="是否区分大小写")
    max_results: int = Field(default=50, description="最大结果数")
    is_regex: bool = Field(default=False, description="是否使用正则表达式")


class FileSearchTool(AgentTool):
    """
    Code search tool.

    Scans project files for lines matching a literal keyword or a regular
    expression and reports each hit with one line of context on either side.
    """

    # Directory names that are never descended into.
    EXCLUDE_DIRS = {
        "node_modules", "vendor", "dist", "build", ".git",
        "__pycache__", ".pytest_cache", "coverage", ".nyc_output",
        ".vscode", ".idea", ".vs", "target", "venv", "env",
    }

    def __init__(self, project_root: str):
        """
        Initialize the search tool.

        Args:
            project_root: Project root directory; searches never leave it.
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "search_code"

    @property
    def description(self) -> str:
        return """在项目代码中搜索关键字或模式。 使用场景: - 查找特定函数的所有调用位置 - 搜索特定的 API 使用 - 查找包含特定模式的代码 输入: - keyword: 搜索关键字或正则表达式 - file_pattern: 可选,文件名模式(如 *.py) - directory: 可选,搜索目录 - case_sensitive: 是否区分大小写 - is_regex: 是否使用正则表达式 这是一个快速搜索工具,结果包含匹配行和上下文。"""

    @property
    def args_schema(self):
        return FileSearchInput

    async def _execute(
        self,
        keyword: str,
        file_pattern: Optional[str] = None,
        directory: Optional[str] = None,
        case_sensitive: bool = False,
        max_results: int = 50,
        is_regex: bool = False,
        **kwargs
    ) -> ToolResult:
        """
        Search project files for *keyword*.

        Args:
            keyword: Literal text, or a regular expression when ``is_regex``.
            file_pattern: Optional fnmatch pattern applied to file names.
            directory: Optional sub-directory (relative to the project root)
                that restricts the search.
            case_sensitive: Match case exactly when True.
            max_results: Stop scanning after this many matches.
            is_regex: Treat ``keyword`` as a regular expression.

        Returns:
            ToolResult with formatted matches in ``data``; ``metadata`` holds
            the counts and the first 10 structured matches. Errors are
            reported as ``success=False``.
        """
        try:
            # Determine (and sandbox) the directory to search.
            root = os.path.realpath(self.project_root)
            if directory:
                search_dir = _resolve_in_project(root, directory)
                if search_dir is None:
                    return ToolResult(
                        success=False,
                        error="安全错误:不允许搜索项目目录外的内容",
                    )
            else:
                search_dir = root

            if not os.path.exists(search_dir):
                return ToolResult(
                    success=False,
                    error=f"目录不存在: {directory or '.'}",
                )
            if not os.path.isdir(search_dir):
                return ToolResult(
                    success=False,
                    error=f"不是目录: {directory}",
                )

            # Compile the search pattern; literal keywords are escaped.
            flags = 0 if case_sensitive else re.IGNORECASE
            try:
                pattern = re.compile(keyword if is_regex else re.escape(keyword), flags)
            except re.error as e:
                return ToolResult(
                    success=False,
                    error=f"无效的搜索模式: {e}",
                )

            results: List[Dict[str, Any]] = []
            files_searched = 0

            for file_path in self._iter_files(root, search_dir, file_pattern):
                try:
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        lines = f.readlines()
                except OSError:
                    # Best effort: unreadable files (permissions, files removed
                    # mid-walk) are skipped rather than aborting the search.
                    continue

                files_searched += 1
                relative_path = os.path.relpath(file_path, root)

                for i, line in enumerate(lines):
                    if not pattern.search(line):
                        continue
                    results.append(self._build_match(relative_path, lines, i))
                    if len(results) >= max_results:
                        break

                if len(results) >= max_results:
                    break

            if not results:
                return ToolResult(
                    success=True,
                    data=f"没有找到匹配 '{keyword}' 的内容\n搜索了 {files_searched} 个文件",
                    metadata={"files_searched": files_searched, "matches": 0}
                )

            # Format the output.
            output_parts = [f"🔍 搜索结果: '{keyword}'\n"]
            output_parts.append(f"找到 {len(results)} 处匹配(搜索了 {files_searched} 个文件)\n")

            for result in results:
                output_parts.append(f"\n📄 {result['file']}:{result['line']}")
                output_parts.append(f"```\n{result['context']}\n```")

            if len(results) >= max_results:
                output_parts.append(f"\n... 结果已截断(最大 {max_results} 条)")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "keyword": keyword,
                    "files_searched": files_searched,
                    "matches": len(results),
                    "results": results[:10],  # keep only the first 10 in metadata
                }
            )

        except Exception as e:
            # Tool boundary: report any unexpected failure to the caller.
            return ToolResult(
                success=False,
                error=f"搜索失败: {str(e)}",
            )

    def _iter_files(self, root: str, search_dir: str, file_pattern: Optional[str]):
        """Yield paths of candidate files under *search_dir*.

        Skips EXCLUDE_DIRS, names not matching *file_pattern*, and file
        symlinks whose target lies outside *root* (os.walk does not descend
        into symlinked directories by default, so only file links matter).
        """
        for current, dirnames, filenames in os.walk(search_dir):
            dirnames[:] = [d for d in dirnames if d not in self.EXCLUDE_DIRS]
            for filename in filenames:
                if file_pattern and not fnmatch.fnmatch(filename, file_pattern):
                    continue
                path = os.path.join(current, filename)
                if os.path.islink(path) and not _is_within(root, os.path.realpath(path)):
                    continue
                yield path

    @staticmethod
    def _build_match(relative_path: str, lines: List[str], i: int) -> Dict[str, Any]:
        """Describe the match on 0-based line *i*, with one context line each side."""
        start = max(0, i - 1)
        end = min(len(lines), i + 2)
        context_lines = []
        for j in range(start, end):
            prefix = ">" if j == i else " "
            context_lines.append(f"{prefix} {j+1:4d}| {lines[j].rstrip()}")
        return {
            "file": relative_path,
            "line": i + 1,
            "match": lines[i].strip()[:200],
            "context": '\n'.join(context_lines),
        }


# Input schema for ListFilesTool (docstring/descriptions feed the JSON schema).
class ListFilesInput(BaseModel):
    """列出文件输入"""
    directory: str = Field(default=".", description="目录路径(相对于项目根目录)")
    pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py")
    recursive: bool = Field(default=False, description="是否递归列出子目录")
    max_files: int = Field(default=100, description="最大文件数")


class ListFilesTool(AgentTool):
    """
    File listing tool.

    Lists files (and, non-recursively, sub-directories) of a project directory.
    """

    # Directory names that are hidden and never descended into.
    EXCLUDE_DIRS = {
        "node_modules", "vendor", "dist", "build", ".git",
        "__pycache__", ".pytest_cache", "coverage",
    }

    def __init__(self, project_root: str):
        """
        Initialize the listing tool.

        Args:
            project_root: Project root directory; listings never leave it.
        """
        super().__init__()
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "list_files"

    @property
    def description(self) -> str:
        return """列出目录中的文件。 使用场景: - 了解项目结构 - 查找特定类型的文件 - 浏览目录内容 输入: - directory: 目录路径 - pattern: 可选,文件名模式 - recursive: 是否递归 - max_files: 最大文件数"""

    @property
    def args_schema(self):
        return ListFilesInput

    async def _execute(
        self,
        directory: str = ".",
        pattern: Optional[str] = None,
        recursive: bool = False,
        max_files: int = 100,
        **kwargs
    ) -> ToolResult:
        """
        List files in *directory*.

        Args:
            directory: Directory relative to the project root.
            pattern: Optional fnmatch pattern applied to file names.
            recursive: Walk sub-directories too (sub-directories themselves
                are then not listed separately).
            max_files: Stop after this many files.

        Returns:
            ToolResult with a formatted listing in ``data`` and file/dir
            counts in ``metadata``. Errors are reported as ``success=False``.
        """
        try:
            root = os.path.realpath(self.project_root)
            target_dir = _resolve_in_project(root, directory)
            if target_dir is None:
                return ToolResult(
                    success=False,
                    error="安全错误:不允许访问项目目录外的目录",
                )

            if not os.path.exists(target_dir):
                return ToolResult(
                    success=False,
                    error=f"目录不存在: {directory}",
                )

            if not os.path.isdir(target_dir):
                return ToolResult(
                    success=False,
                    error=f"不是目录: {directory}",
                )

            if recursive:
                files = self._list_recursive(root, target_dir, pattern, max_files)
                dirs: List[str] = []
            else:
                files, dirs = self._list_flat(root, target_dir, pattern, max_files)

            # Format the output.
            output_parts = [f"📁 目录: {directory}\n"]

            if dirs:
                output_parts.append("目录:")
                for dir_name in sorted(dirs)[:20]:
                    output_parts.append(f" 📂 {dir_name}")
                if len(dirs) > 20:
                    output_parts.append(f" ... 还有 {len(dirs) - 20} 个目录")

            if files:
                output_parts.append(f"\n文件 ({len(files)}):")
                for file_name in sorted(files):
                    output_parts.append(f" 📄 {file_name}")

            if len(files) >= max_files:
                output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "directory": directory,
                    "file_count": len(files),
                    "dir_count": len(dirs),
                }
            )

        except Exception as e:
            # Tool boundary: report any unexpected failure to the caller.
            return ToolResult(
                success=False,
                error=f"列出文件失败: {str(e)}",
            )

    def _list_recursive(
        self, root: str, target_dir: str, pattern: Optional[str], max_files: int
    ) -> List[str]:
        """Collect project-relative file paths under *target_dir*, at most *max_files*.

        Traversal is sorted so truncation at *max_files* is deterministic.
        """
        files: List[str] = []
        for current, dirnames, filenames in os.walk(target_dir):
            dirnames[:] = sorted(d for d in dirnames if d not in self.EXCLUDE_DIRS)
            for filename in sorted(filenames):
                if pattern and not fnmatch.fnmatch(filename, pattern):
                    continue
                files.append(os.path.relpath(os.path.join(current, filename), root))
                if len(files) >= max_files:
                    return files
        return files

    def _list_flat(
        self, root: str, target_dir: str, pattern: Optional[str], max_files: int
    ) -> Tuple[List[str], List[str]]:
        """List the direct children of *target_dir*.

        Returns:
            (files, dirs) as project-relative paths; directory entries end
            with "/". EXCLUDE_DIRS hides directories only (matching recursive
            mode), and entries are visited in sorted order so truncation at
            *max_files* is deterministic.
        """
        files: List[str] = []
        dirs: List[str] = []
        for item in sorted(os.listdir(target_dir)):
            full_path = os.path.join(target_dir, item)
            relative_path = os.path.relpath(full_path, root)

            if os.path.isdir(full_path):
                if item not in self.EXCLUDE_DIRS:
                    dirs.append(relative_path + "/")
                continue

            if pattern and not fnmatch.fnmatch(item, pattern):
                continue
            files.append(relative_path)
            if len(files) >= max_files:
                break
        return files, dirs