diff --git a/backend/app/services/agent/tools/file_tool.py b/backend/app/services/agent/tools/file_tool.py
index f8b1d49..009cf44 100644
--- a/backend/app/services/agent/tools/file_tool.py
+++ b/backend/app/services/agent/tools/file_tool.py
@@ -6,6 +6,7 @@
 import os
 import re
 import fnmatch
+import asyncio
 from typing import Optional, List, Dict, Any
 from pydantic import BaseModel, Field
 
@@ -44,7 +45,37 @@ class FileReadTool(AgentTool):
         self.project_root = project_root
         self.exclude_patterns = exclude_patterns or []
         self.target_files = set(target_files) if target_files else None
-
+
+    @staticmethod
+    def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
+        """Synchronously read a range of lines from a file (used via asyncio.to_thread)"""
+        selected_lines = []
+        total_lines = 0
+        file_size = os.path.getsize(file_path)
+
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            for i, line in enumerate(f):
+                total_lines = i + 1
+                if i >= start_idx and i < end_idx:
+                    selected_lines.append(line)
+                elif i >= end_idx:
+                    if i < end_idx + 1000:
+                        continue
+                    else:
+                        remaining_bytes = file_size - f.tell()
+                        avg_line_size = f.tell() / (i + 1)
+                        estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
+                        total_lines = i + 1 + estimated_remaining_lines
+                        break
+
+        return selected_lines, total_lines
+
+    @staticmethod
+    def _read_all_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines of a file (used via asyncio.to_thread)"""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "read_file"
@@ -136,51 +167,34 @@ class FileReadTool(AgentTool):
 
         # 🔥 For large files, stream-read only the requested line range
         if is_large_file and (start_line is not None or end_line is not None):
-            # Stream the read to avoid loading the whole file at once
-            selected_lines = []
-            total_lines = 0
-
             # Compute the effective start and end lines
             start_idx = max(0, (start_line or 1) - 1)
            end_idx = end_line if end_line else start_idx + max_lines
-
-            with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                for i, line in enumerate(f):
-                    total_lines = i + 1
-                    if i >= start_idx and i < end_idx:
-                        selected_lines.append(line)
-                    elif i >= end_idx:
-                        # Keep counting toward the total, but cap the extra reads
-                        if i < end_idx + 1000:  # read at most 1000 more lines to estimate the total
-                            continue
-                        else:
-                            # Estimate the remaining line count
-                            remaining_bytes = file_size - f.tell()
-                            avg_line_size = f.tell() / (i + 1)
-                            estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
-                            total_lines = i + 1 + estimated_remaining_lines
-                            break
-
+
+            # Read the file asynchronously to avoid blocking the event loop
+            selected_lines, total_lines = await asyncio.to_thread(
+                self._read_file_lines_sync, full_path, start_idx, end_idx
+            )
+
             # Update the effective end index
             end_idx = min(end_idx, start_idx + len(selected_lines))
         else:
-            # Read small files normally
-            with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                lines = f.readlines()
-
+            # Read small files asynchronously as well, to avoid blocking the event loop
+            lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)
+
             total_lines = len(lines)
-
+
             # Resolve the requested line range
             if start_line is not None:
                 start_idx = max(0, start_line - 1)
             else:
                 start_idx = 0
-
+
             if end_line is not None:
                 end_idx = min(total_lines, end_line)
             else:
                 end_idx = min(total_lines, start_idx + max_lines)
-
+
             # Slice out the requested lines
             selected_lines = lines[start_idx:end_idx]
@@ -259,7 +273,7 @@ class FileSearchTool(AgentTool):
         self.project_root = project_root
         self.exclude_patterns = exclude_patterns or []
         self.target_files = set(target_files) if target_files else None
-
+
         # Derive directory exclusions from exclude_patterns
         self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
         for pattern in self.exclude_patterns:
@@ -267,7 +281,13 @@ class FileSearchTool(AgentTool):
                 self.exclude_dirs.add(pattern[:-3])
             elif "/" not in pattern and "*" not in pattern:
                 self.exclude_dirs.add(pattern)
-
+
+    @staticmethod
+    def _read_file_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines of a file (used via asyncio.to_thread)"""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "search_code"
@@ -360,11 +380,13 @@ class FileSearchTool(AgentTool):
                 continue
 
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    lines = f.readlines()
-
+                # Read the file asynchronously to avoid blocking the event loop
+                lines = await asyncio.to_thread(
+                    self._read_file_lines_sync, file_path
+                )
+
                 files_searched += 1
-
+
                 for i, line in enumerate(lines):
                     if pattern.search(line):
                         # Grab the surrounding context
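Every file_tool.py hunk above applies the same recipe: keep the blocking `open()`/`readlines()` work in a small synchronous helper and dispatch it through `asyncio.to_thread`, so the coroutine awaits a worker thread instead of stalling the event loop. A minimal, runnable sketch of that recipe — the temp file and the `read_lines_sync` helper are illustrative, not part of this codebase:

```python
import asyncio
import os
import tempfile

def read_lines_sync(path: str, start_idx: int, end_idx: int) -> list:
    """Blocking helper: collect lines in [start_idx, end_idx), like _read_file_lines_sync."""
    selected = []
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        for i, line in enumerate(f):
            if i >= end_idx:
                break
            if i >= start_idx:
                selected.append(line)
    return selected

async def main() -> None:
    # Throwaway demo file so the sketch runs standalone
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
        tmp.writelines(f"line {n}\n" for n in range(10_000))
        path = tmp.name
    try:
        # The blocking read runs on a worker thread; the event loop stays
        # free to schedule other coroutines while we await the result.
        lines = await asyncio.to_thread(read_lines_sync, path, 100, 103)
        print(lines)  # ['line 100\n', 'line 101\n', 'line 102\n']
    finally:
        os.unlink(path)

asyncio.run(main())
```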
diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py
index d82ba68..bdb15ed 100644
--- a/backend/app/services/rag/indexer.py
+++ b/backend/app/services/rag/indexer.py
@@ -739,6 +739,20 @@ class CodeIndexer:
         self._needs_rebuild = False
         self._rebuild_reason = ""
 
+    @staticmethod
+    def _read_file_sync(file_path: str) -> str:
+        """
+        Synchronously read a file's content (wrapped with asyncio.to_thread)
+
+        Args:
+            file_path: path to the file
+
+        Returns:
+            The file content
+        """
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.read()
+
     async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
         """
         Initialize the indexer and detect whether the index needs rebuilding
@@ -916,8 +930,10 @@ class CodeIndexer:
             try:
                 relative_path = os.path.relpath(file_path, directory)
 
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -932,8 +948,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # Chunking
-                chunks = self.splitter.split_file(content, relative_path)
+                # Chunk asynchronously so Tree-sitter parsing doesn't block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # Attach file_hash to each chunk
                 for chunk in chunks:
@@ -1018,8 +1034,10 @@ class CodeIndexer:
         for relative_path in files_to_check:
             file_path = current_file_map[relative_path]
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
                 current_hash = hashlib.md5(content.encode()).hexdigest()
                 if current_hash != indexed_file_hashes.get(relative_path):
                     files_to_update.add(relative_path)
@@ -1055,8 +1073,10 @@ class CodeIndexer:
             is_update = relative_path in files_to_update
 
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -1075,8 +1095,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # Chunking
-                chunks = self.splitter.split_file(content, relative_path)
+                # Chunk asynchronously so Tree-sitter parsing doesn't block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # Attach file_hash to each chunk
                 for chunk in chunks:
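In the indexer, the same offload backs the change check: content comes back from `asyncio.to_thread(self._read_file_sync, ...)`, and its MD5 is compared against the stored hash to decide whether a file needs re-indexing. A hedged sketch of that detection step — the names `read_file_sync` and `changed_files` are illustrative, and the temp file just lets it run standalone:

```python
import asyncio
import hashlib
import os
import tempfile

def read_file_sync(path: str) -> str:
    """Blocking read, the same shape as the indexer's _read_file_sync helper."""
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        return f.read()

async def changed_files(paths: list, known_hashes: dict) -> set:
    """Hash each file off the event loop and report which ones changed."""
    async def has_changed(path: str) -> bool:
        content = await asyncio.to_thread(read_file_sync, path)
        digest = hashlib.md5(content.encode()).hexdigest()
        return digest != known_hashes.get(path)

    # Checks run concurrently; each blocking read occupies a worker thread
    flags = await asyncio.gather(*(has_changed(p) for p in paths))
    return {p for p, changed in zip(paths, flags) if changed}

async def main() -> None:
    with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as tmp:
        tmp.write("print('hello')\n")
        path = tmp.name
    try:
        # No stored hash yet, so the file reports as changed
        print(await changed_files([path], {}))
    finally:
        os.unlink(path)

asyncio.run(main())
```

The `asyncio.gather` fan-out is an embellishment over the diff's sequential loop; the default thread pool still caps how many reads are truly in flight.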
diff --git a/backend/app/services/rag/splitter.py b/backend/app/services/rag/splitter.py
index 65d4f16..5c350b0 100644
--- a/backend/app/services/rag/splitter.py
+++ b/backend/app/services/rag/splitter.py
@@ -4,6 +4,7 @@
 """
 import re
+import asyncio
 import hashlib
 import logging
 from typing import List, Dict, Any, Optional, Tuple, Set
 
@@ -230,21 +231,30 @@ class TreeSitterParser:
         return False
 
     def parse(self, code: str, language: str) -> Optional[Any]:
-        """Parse code and return the AST"""
+        """Parse code and return the AST (synchronous)"""
         if not self._ensure_initialized(language):
             return None
-
+
         parser = self._parsers.get(language)
         if not parser:
             return None
-
+
         try:
             tree = parser.parse(code.encode())
             return tree
         except Exception as e:
             logger.warning(f"Failed to parse code: {e}")
             return None
-
+
+    async def parse_async(self, code: str, language: str) -> Optional[Any]:
+        """
+        Asynchronously parse code and return the AST
+
+        Runs the CPU-bound Tree-sitter parse in the thread pool so it
+        doesn't block the event loop
+        """
+        return await asyncio.to_thread(self.parse, code, language)
+
     def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
         """Extract definitions from the AST"""
         if tree is None:
@@ -449,9 +459,31 @@ class CodeSplitter:
         except Exception as e:
             logger.warning(f"Chunking failed for {file_path}: {e}, falling back to simple line splitting")
             chunks = self._split_by_lines(content, file_path, language)
-
+
         return chunks
-
+
+    async def split_file_async(
+        self,
+        content: str,
+        file_path: str,
+        language: Optional[str] = None
+    ) -> List[CodeChunk]:
+        """
+        Asynchronously split a single file
+
+        Runs the CPU-bound chunking work (including Tree-sitter parsing)
+        in the thread pool so it doesn't block the event loop.
+
+        Args:
+            content: file content
+            file_path: file path
+            language: programming language (optional)
+
+        Returns:
+            List of code chunks
+        """
+        return await asyncio.to_thread(self.split_file, content, file_path, language)
+
     def _split_by_ast(
         self,
         content: str,
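Both splitter.py additions are the thin async-wrapper idiom: the synchronous `parse` and `split_file` remain the single implementation, and `parse_async`/`split_file_async` merely delegate through `asyncio.to_thread`. A minimal sketch of the idiom, assuming a stand-in `Splitter` whose two-line chunking substitutes for the real Tree-sitter work:

```python
import asyncio

class Splitter:
    """Stand-in for CodeSplitter; only the wrapper pattern mirrors the diff."""

    def split_file(self, content: str, file_path: str) -> list:
        # Placeholder for the CPU-bound chunking (the real code parses an AST)
        lines = content.splitlines()
        return ["\n".join(lines[i:i + 2]) for i in range(0, len(lines), 2)]

    async def split_file_async(self, content: str, file_path: str) -> list:
        # Same signature and behavior, but executed on a worker thread
        return await asyncio.to_thread(self.split_file, content, file_path)

async def main() -> None:
    chunks = await Splitter().split_file_async("a\nb\nc\nd\ne", "demo.py")
    print(chunks)  # ['a\nb', 'c\nd', 'e']

asyncio.run(main())
```

One caveat: `asyncio.to_thread` keeps the loop responsive, but pure-Python CPU work still contends for the GIL, so the offload pays off most when the heavy part runs in C code that releases it — file I/O does, while whether a given py-tree-sitter version releases it during `parse` should be verified before counting on parallel speedups.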