feat: add async support for file operations and code parsing

Convert synchronous file reading, code parsing, and chunking to async implementations, using asyncio.to_thread to run blocking file I/O and CPU-bound parsing in a thread pool so the event loop is not blocked. Main changes:
- add async parse methods to TreeSitterParser and CodeSplitter
- switch CodeIndexer to async file reading and chunking
- add async file reading support to FileReadTool and FileSearchTool
lintsinghua 2025-12-25 17:20:42 +08:00
parent fdbec80da5
commit c7632afdab
3 changed files with 126 additions and 52 deletions
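
The same pattern repeats across all three files: keep a small synchronous helper and hand it to asyncio.to_thread (Python 3.9+) from the async caller, so the blocking work runs on a worker thread. Below is a minimal standalone sketch of that pattern, not code from this commit; the helper name and file path are illustrative.

import asyncio
from typing import List

def _read_all_lines_sync(path: str) -> List[str]:
    # Blocking read; runs on a worker thread when wrapped with asyncio.to_thread
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        return f.readlines()

async def read_lines(path: str) -> List[str]:
    # Offload the blocking read so the event loop stays responsive
    return await asyncio.to_thread(_read_all_lines_sync, path)

if __name__ == "__main__":
    lines = asyncio.run(read_lines("example.txt"))  # "example.txt" is a placeholder
    print(len(lines))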

View File

@@ -6,6 +6,7 @@
 import os
 import re
 import fnmatch
+import asyncio
 from typing import Optional, List, Dict, Any
 from pydantic import BaseModel, Field
@@ -44,7 +45,37 @@ class FileReadTool(AgentTool):
         self.project_root = project_root
         self.exclude_patterns = exclude_patterns or []
         self.target_files = set(target_files) if target_files else None
 
+    @staticmethod
+    def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
+        """Synchronously read the given line range of a file (for asyncio.to_thread)."""
+        selected_lines = []
+        total_lines = 0
+        file_size = os.path.getsize(file_path)
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            for i, line in enumerate(f):
+                total_lines = i + 1
+                if i >= start_idx and i < end_idx:
+                    selected_lines.append(line)
+                elif i >= end_idx:
+                    if i < end_idx + 1000:
+                        continue
+                    else:
+                        remaining_bytes = file_size - f.tell()
+                        avg_line_size = f.tell() / (i + 1)
+                        estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
+                        total_lines = i + 1 + estimated_remaining_lines
+                        break
+        return selected_lines, total_lines
+
+    @staticmethod
+    def _read_all_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines of a file (for asyncio.to_thread)."""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "read_file"
@@ -136,51 +167,34 @@ class FileReadTool(AgentTool):
         # 🔥 For large files, stream-read the specified line range
         if is_large_file and (start_line is not None or end_line is not None):
-            # Stream the read to avoid loading the whole file at once
-            selected_lines = []
-            total_lines = 0
             # Compute the actual start and end lines
             start_idx = max(0, (start_line or 1) - 1)
             end_idx = end_line if end_line else start_idx + max_lines
 
-            with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                for i, line in enumerate(f):
-                    total_lines = i + 1
-                    if i >= start_idx and i < end_idx:
-                        selected_lines.append(line)
-                    elif i >= end_idx:
-                        # Keep counting to get the total line count, but cap the extra reading
-                        if i < end_idx + 1000:  # read at most 1000 more lines to estimate the total
-                            continue
-                        else:
-                            # Estimate the number of remaining lines
-                            remaining_bytes = file_size - f.tell()
-                            avg_line_size = f.tell() / (i + 1)
-                            estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
-                            total_lines = i + 1 + estimated_remaining_lines
-                            break
+            # Read the file asynchronously to avoid blocking the event loop
+            selected_lines, total_lines = await asyncio.to_thread(
+                self._read_file_lines_sync, full_path, start_idx, end_idx
+            )
 
             # Update the actual end index
             end_idx = min(end_idx, start_idx + len(selected_lines))
         else:
-            # Read small files normally
-            with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                lines = f.readlines()
+            # Read small files asynchronously to avoid blocking the event loop
+            lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)
             total_lines = len(lines)
 
             # Handle the line range
             if start_line is not None:
                 start_idx = max(0, start_line - 1)
             else:
                 start_idx = 0
 
             if end_line is not None:
                 end_idx = min(total_lines, end_line)
             else:
                 end_idx = min(total_lines, start_idx + max_lines)
 
             # Slice out the requested lines
             selected_lines = lines[start_idx:end_idx]
@@ -259,7 +273,7 @@ class FileSearchTool(AgentTool):
         self.project_root = project_root
         self.exclude_patterns = exclude_patterns or []
         self.target_files = set(target_files) if target_files else None
 
         # Derive directory exclusions from exclude_patterns
         self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
         for pattern in self.exclude_patterns:
@@ -267,7 +281,13 @@ class FileSearchTool(AgentTool):
                 self.exclude_dirs.add(pattern[:-3])
             elif "/" not in pattern and "*" not in pattern:
                 self.exclude_dirs.add(pattern)
 
+    @staticmethod
+    def _read_file_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines of a file (for asyncio.to_thread)."""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "search_code"
@@ -360,11 +380,13 @@ class FileSearchTool(AgentTool):
                 continue
 
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    lines = f.readlines()
+                # Read the file asynchronously to avoid blocking the event loop
+                lines = await asyncio.to_thread(
+                    self._read_file_lines_sync, file_path
+                )
                 files_searched += 1
 
                 for i, line in enumerate(lines):
                     if pattern.search(line):
                         # Get the surrounding context
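
A note on the large-file path in _read_file_lines_sync above: after the requested range, it scans at most 1000 more lines and then extrapolates the total line count from the average line size seen so far. A rough worked example with made-up numbers (not from the commit):

# Suppose the scan stops after 6,000 lines with the file handle at 240,000 bytes,
# and the file is 1,040,000 bytes in total (all figures illustrative).
bytes_read = 240_000
lines_read = 6_000
file_size = 1_040_000

avg_line_size = bytes_read / lines_read                      # 40.0 bytes per line
remaining_bytes = file_size - bytes_read                     # 800,000 bytes
estimated_remaining = int(remaining_bytes / avg_line_size)   # 20,000 lines
total_lines = lines_read + estimated_remaining               # 26,000 lines
print(total_lines)  # 26000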

View File

@@ -739,6 +739,20 @@ class CodeIndexer:
         self._needs_rebuild = False
         self._rebuild_reason = ""
 
+    @staticmethod
+    def _read_file_sync(file_path: str) -> str:
+        """
+        Synchronously read a file's contents (for wrapping with asyncio.to_thread).
+
+        Args:
+            file_path: path to the file
+
+        Returns:
+            The file contents
+        """
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.read()
+
     async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
         """
         Initialize the indexer and check whether the index needs to be rebuilt
@@ -916,8 +930,10 @@ class CodeIndexer:
             try:
                 relative_path = os.path.relpath(file_path, directory)
 
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -932,8 +948,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # Split into chunks
-                chunks = self.splitter.split_file(content, relative_path)
+                # Chunk asynchronously so Tree-sitter parsing does not block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # Add file_hash to each chunk
                 for chunk in chunks:
@@ -1018,8 +1034,10 @@ class CodeIndexer:
         for relative_path in files_to_check:
             file_path = current_file_map[relative_path]
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
                 current_hash = hashlib.md5(content.encode()).hexdigest()
 
                 if current_hash != indexed_file_hashes.get(relative_path):
                     files_to_update.add(relative_path)
@@ -1055,8 +1073,10 @@ class CodeIndexer:
             is_update = relative_path in files_to_update
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -1075,8 +1095,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # Split into chunks
-                chunks = self.splitter.split_file(content, relative_path)
+                # Chunk asynchronously so Tree-sitter parsing does not block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # Add file_hash to each chunk
                 for chunk in chunks:

View File

@@ -4,6 +4,7 @@
 """
 import re
+import asyncio
 import hashlib
 import logging
 from typing import List, Dict, Any, Optional, Tuple, Set
@@ -230,21 +231,30 @@ class TreeSitterParser:
         return False
 
     def parse(self, code: str, language: str) -> Optional[Any]:
-        """Parse code and return the AST"""
+        """Parse code and return the AST (synchronous method)"""
         if not self._ensure_initialized(language):
             return None
 
         parser = self._parsers.get(language)
         if not parser:
             return None
 
         try:
             tree = parser.parse(code.encode())
             return tree
         except Exception as e:
             logger.warning(f"Failed to parse code: {e}")
             return None
 
+    async def parse_async(self, code: str, language: str) -> Optional[Any]:
+        """
+        Asynchronously parse code and return the AST.
+
+        The CPU-bound Tree-sitter parse runs in a thread pool so it does not
+        block the event loop.
+        """
+        return await asyncio.to_thread(self.parse, code, language)
+
     def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
         """Extract definitions from the AST"""
         if tree is None:
@@ -449,9 +459,31 @@ class CodeSplitter:
         except Exception as e:
             logger.warning(f"Chunking failed for {file_path}: {e}, using simple chunking")
             chunks = self._split_by_lines(content, file_path, language)
 
         return chunks
 
+    async def split_file_async(
+        self,
+        content: str,
+        file_path: str,
+        language: Optional[str] = None
+    ) -> List[CodeChunk]:
+        """
+        Asynchronously split a single file into chunks.
+
+        The CPU-bound chunking work (including Tree-sitter parsing) runs in a
+        thread pool so it does not block the event loop.
+
+        Args:
+            content: file contents
+            file_path: file path
+            language: programming language (optional)
+
+        Returns:
+            List of code chunks
+        """
+        return await asyncio.to_thread(self.split_file, content, file_path, language)
+
     def _split_by_ast(
         self,
         content: str,