feat: add async support for file operations and code parsing

Convert synchronous file reading, code parsing, and chunking to asynchronous implementations, using asyncio.to_thread to run the CPU-bound work in a thread pool so the event loop is not blocked (a minimal sketch of the pattern follows the commit metadata below). Main changes:
- Add async parsing methods to TreeSitterParser and CodeSplitter
- Switch CodeIndexer to asynchronous file reading and chunking
- Add asynchronous file-reading support to FileReadTool and FileSearchTool
lintsinghua 2025-12-25 17:20:42 +08:00
parent fdbec80da5
commit c7632afdab
3 changed files with 126 additions and 52 deletions
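
All three files follow the same pattern: the blocking work lives in a small synchronous helper, and the async caller hands that helper to asyncio.to_thread (Python 3.9+), which runs it in the default thread-pool executor. The sketch below is a minimal, self-contained illustration of that pattern, not code from this commit; the class name and sample file are hypothetical.

import asyncio

class AsyncFileReader:
    """Hypothetical example class; mirrors the sync-helper + to_thread split used in this commit."""

    @staticmethod
    def _read_file_sync(file_path: str) -> str:
        # Plain blocking I/O; safe to run in a worker thread.
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()

    async def read_file(self, file_path: str) -> str:
        # asyncio.to_thread submits the sync helper to the default executor,
        # so the event loop stays free while the read is in flight.
        return await asyncio.to_thread(self._read_file_sync, file_path)

async def main() -> None:
    # Write a small sample file so the example runs standalone.
    with open("example.txt", "w", encoding="utf-8") as f:
        f.write("hello asyncio\n")
    reader = AsyncFileReader()
    content = await reader.read_file("example.txt")
    print(f"read {len(content)} characters")

if __name__ == "__main__":
    asyncio.run(main())

Because to_thread only moves the call onto another thread, each helper is kept as a @staticmethod that takes plain arguments and touches no event-loop state, which is also how the helpers in this commit are written.
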

View File

@@ -6,6 +6,7 @@
import os
import re
import fnmatch
import asyncio
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
@@ -44,7 +45,37 @@ class FileReadTool(AgentTool):
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
@staticmethod
def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
"""同步读取文件指定行范围(用于 asyncio.to_thread"""
selected_lines = []
total_lines = 0
file_size = os.path.getsize(file_path)
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
for i, line in enumerate(f):
total_lines = i + 1
if i >= start_idx and i < end_idx:
selected_lines.append(line)
elif i >= end_idx:
if i < end_idx + 1000:
continue
else:
remaining_bytes = file_size - f.tell()
avg_line_size = f.tell() / (i + 1)
estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
total_lines = i + 1 + estimated_remaining_lines
break
return selected_lines, total_lines
@staticmethod
def _read_all_lines_sync(file_path: str) -> List[str]:
"""同步读取文件所有行(用于 asyncio.to_thread"""
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.readlines()
@property
def name(self) -> str:
return "read_file"
@@ -136,51 +167,34 @@ class FileReadTool(AgentTool):
            # 🔥 For large files, stream-read only the requested line range
            if is_large_file and (start_line is not None or end_line is not None):
                # Stream the file instead of loading it all at once
selected_lines = []
total_lines = 0
                # Compute the effective start and end line indexes
start_idx = max(0, (start_line or 1) - 1)
end_idx = end_line if end_line else start_idx + max_lines
with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
for i, line in enumerate(f):
total_lines = i + 1
if i >= start_idx and i < end_idx:
selected_lines.append(line)
elif i >= end_idx:
                            # Keep counting to get the total line count, but cap how much extra is read
                            if i < end_idx + 1000:  # read at most 1000 more lines to estimate the total
continue
else:
                                # Estimate the remaining line count
remaining_bytes = file_size - f.tell()
avg_line_size = f.tell() / (i + 1)
estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
total_lines = i + 1 + estimated_remaining_lines
break
                # Read the file asynchronously to avoid blocking the event loop
selected_lines, total_lines = await asyncio.to_thread(
self._read_file_lines_sync, full_path, start_idx, end_idx
)
                # Update the effective end index
end_idx = min(end_idx, start_idx + len(selected_lines))
else:
                # Small file: read it normally
with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
                # Read the small file asynchronously to avoid blocking the event loop
lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)
total_lines = len(lines)
                # Handle the requested line range
if start_line is not None:
start_idx = max(0, start_line - 1)
else:
start_idx = 0
if end_line is not None:
end_idx = min(total_lines, end_line)
else:
end_idx = min(total_lines, start_idx + max_lines)
            # Slice out the requested lines
selected_lines = lines[start_idx:end_idx]
@@ -259,7 +273,7 @@ class FileSearchTool(AgentTool):
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
        # Derive directory exclusions from exclude_patterns
self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
for pattern in self.exclude_patterns:
@@ -267,7 +281,13 @@ class FileSearchTool(AgentTool):
self.exclude_dirs.add(pattern[:-3])
elif "/" not in pattern and "*" not in pattern:
self.exclude_dirs.add(pattern)
@staticmethod
def _read_file_lines_sync(file_path: str) -> List[str]:
"""同步读取文件所有行(用于 asyncio.to_thread"""
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.readlines()
@property
def name(self) -> str:
return "search_code"
@@ -360,11 +380,13 @@ class FileSearchTool(AgentTool):
continue
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
                    # Read the file asynchronously to avoid blocking the event loop
lines = await asyncio.to_thread(
self._read_file_lines_sync, file_path
)
files_searched += 1
for i, line in enumerate(lines):
if pattern.search(line):
                        # Collect the surrounding context

View File

@@ -739,6 +739,20 @@ class CodeIndexer:
self._needs_rebuild = False
self._rebuild_reason = ""
@staticmethod
def _read_file_sync(file_path: str) -> str:
"""
        Synchronously read a file's contents (for wrapping with asyncio.to_thread).
        Args:
            file_path: path of the file to read
        Returns:
            the file contents
"""
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.read()
async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
"""
        Initialize the indexer and check whether the index needs to be rebuilt.
@@ -916,8 +930,10 @@ class CodeIndexer:
try:
relative_path = os.path.relpath(file_path, directory)
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
                # Read the file asynchronously to avoid blocking the event loop
content = await asyncio.to_thread(
self._read_file_sync, file_path
)
if not content.strip():
progress.processed_files += 1
@@ -932,8 +948,8 @@ class CodeIndexer:
if len(content) > 500000:
content = content[:500000]
                # Chunk the file
chunks = self.splitter.split_file(content, relative_path)
                # Chunk asynchronously so Tree-sitter parsing does not block the event loop
chunks = await self.splitter.split_file_async(content, relative_path)
                # Attach file_hash to each chunk
for chunk in chunks:
@@ -1018,8 +1034,10 @@ class CodeIndexer:
for relative_path in files_to_check:
file_path = current_file_map[relative_path]
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
                # Read the file asynchronously to avoid blocking the event loop
content = await asyncio.to_thread(
self._read_file_sync, file_path
)
current_hash = hashlib.md5(content.encode()).hexdigest()
if current_hash != indexed_file_hashes.get(relative_path):
files_to_update.add(relative_path)
@@ -1055,8 +1073,10 @@ class CodeIndexer:
is_update = relative_path in files_to_update
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
                # Read the file asynchronously to avoid blocking the event loop
content = await asyncio.to_thread(
self._read_file_sync, file_path
)
if not content.strip():
progress.processed_files += 1
@@ -1075,8 +1095,8 @@ class CodeIndexer:
if len(content) > 500000:
content = content[:500000]
                # Chunk the file
chunks = self.splitter.split_file(content, relative_path)
                # Chunk asynchronously so Tree-sitter parsing does not block the event loop
chunks = await self.splitter.split_file_async(content, relative_path)
                # Attach file_hash to each chunk
for chunk in chunks:

View File

@@ -4,6 +4,7 @@
"""
import re
import asyncio
import hashlib
import logging
from typing import List, Dict, Any, Optional, Tuple, Set
@@ -230,21 +231,30 @@ class TreeSitterParser:
return False
def parse(self, code: str, language: str) -> Optional[Any]:
"""解析代码返回 AST"""
"""解析代码返回 AST(同步方法)"""
if not self._ensure_initialized(language):
return None
parser = self._parsers.get(language)
if not parser:
return None
try:
tree = parser.parse(code.encode())
return tree
except Exception as e:
logger.warning(f"Failed to parse code: {e}")
return None
async def parse_async(self, code: str, language: str) -> Optional[Any]:
"""
        Asynchronously parse code and return the AST.
        The CPU-bound Tree-sitter parsing runs in a thread pool
        so it does not block the event loop.
"""
return await asyncio.to_thread(self.parse, code, language)
def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
"""从 AST 提取定义"""
if tree is None:
@@ -449,9 +459,31 @@ class CodeSplitter:
except Exception as e:
logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
chunks = self._split_by_lines(content, file_path, language)
return chunks
async def split_file_async(
self,
content: str,
file_path: str,
language: Optional[str] = None
) -> List[CodeChunk]:
"""
        Asynchronously split a single file into chunks.
        The CPU-bound chunking work (including Tree-sitter parsing) runs in a
        thread pool so it does not block the event loop.
        Args:
            content: the file contents
            file_path: the file path
            language: programming language (optional)
        Returns:
            list of code chunks
"""
return await asyncio.to_thread(self.split_file, content, file_path, language)
def _split_by_ast(
self,
content: str,