feat: add async support for file operations and code parsing
Convert synchronous file reading, code parsing, and chunking to async implementations, using asyncio.to_thread to run the CPU-heavy work on a thread pool so it does not block the event loop. Main changes:

- add async parsing methods to TreeSitterParser and CodeSplitter
- switch CodeIndexer to async file reading and chunking
- add async file reading to FileReadTool and FileSearchTool
Commit: c7632afdab (parent: fdbec80da5)
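The pattern this commit applies everywhere, as a minimal standalone sketch (file names and helper names here are illustrative, not from the codebase): keep the blocking callable synchronous and hand it to asyncio.to_thread, which runs it on the default thread-pool executor and resumes the coroutine with its result. asyncio.to_thread requires Python 3.9+; for file I/O it always keeps the event loop responsive, while for CPU-bound work it only yields real parallelism when the callee releases the GIL (C extensions such as Tree-sitter's parser generally do).

```python
import asyncio

def _read_text_sync(path: str) -> str:
    # Blocking read: runs on a worker thread, never on the event loop.
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        return f.read()

async def read_text(path: str) -> str:
    # asyncio.to_thread schedules the sync callable on the default
    # ThreadPoolExecutor and suspends this coroutine until it returns.
    return await asyncio.to_thread(_read_text_sync, path)

async def main() -> None:
    # The two reads now overlap instead of stalling the loop one after another.
    a, b = await asyncio.gather(read_text("a.py"), read_text("b.py"))
    print(len(a), len(b))

if __name__ == "__main__":
    asyncio.run(main())
```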
@@ -6,6 +6,7 @@
 import os
 import re
 import fnmatch
+import asyncio
 from typing import Optional, List, Dict, Any
 from pydantic import BaseModel, Field
 
@@ -44,7 +45,37 @@ class FileReadTool(AgentTool):
         self.project_root = project_root
         self.exclude_patterns = exclude_patterns or []
         self.target_files = set(target_files) if target_files else None
 
+    @staticmethod
+    def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
+        """Synchronously read a line range from a file (for asyncio.to_thread)."""
+        selected_lines = []
+        total_lines = 0
+        file_size = os.path.getsize(file_path)
+
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            for i, line in enumerate(f):
+                total_lines = i + 1
+                if i >= start_idx and i < end_idx:
+                    selected_lines.append(line)
+                elif i >= end_idx:
+                    if i < end_idx + 1000:
+                        continue
+                    else:
+                        remaining_bytes = file_size - f.tell()
+                        avg_line_size = f.tell() / (i + 1)
+                        estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
+                        total_lines = i + 1 + estimated_remaining_lines
+                        break
+
+        return selected_lines, total_lines
+
+    @staticmethod
+    def _read_all_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines of a file (for asyncio.to_thread)."""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "read_file"
@@ -136,51 +167,34 @@ class FileReadTool(AgentTool):
 
         # 🔥 For large files, stream-read the requested line range
         if is_large_file and (start_line is not None or end_line is not None):
-            # Stream the read to avoid loading the whole file at once
-            selected_lines = []
-            total_lines = 0
-
             # Compute the effective start and end indices
             start_idx = max(0, (start_line or 1) - 1)
             end_idx = end_line if end_line else start_idx + max_lines
 
-            with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                for i, line in enumerate(f):
-                    total_lines = i + 1
-                    if i >= start_idx and i < end_idx:
-                        selected_lines.append(line)
-                    elif i >= end_idx:
-                        # Keep counting toward the total, but cap the extra reading
-                        if i < end_idx + 1000:  # read at most 1000 more lines to estimate the total
-                            continue
-                        else:
-                            # Estimate the remaining line count
-                            remaining_bytes = file_size - f.tell()
-                            avg_line_size = f.tell() / (i + 1)
-                            estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
-                            total_lines = i + 1 + estimated_remaining_lines
-                            break
-
+            # Read the file asynchronously to avoid blocking the event loop
+            selected_lines, total_lines = await asyncio.to_thread(
+                self._read_file_lines_sync, full_path, start_idx, end_idx
+            )
+
+            # Update the effective end index
+            end_idx = min(end_idx, start_idx + len(selected_lines))
         else:
-            # Plain read for small files
-            with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                lines = f.readlines()
+            # Read the small file asynchronously to avoid blocking the event loop
+            lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)
 
             total_lines = len(lines)
 
             # Resolve the line range
             if start_line is not None:
                 start_idx = max(0, start_line - 1)
             else:
                 start_idx = 0
 
             if end_line is not None:
                 end_idx = min(total_lines, end_line)
             else:
                 end_idx = min(total_lines, start_idx + max_lines)
 
             # Slice the requested lines
             selected_lines = lines[start_idx:end_idx]
 
@@ -259,7 +273,7 @@ class FileSearchTool(AgentTool):
         self.project_root = project_root
         self.exclude_patterns = exclude_patterns or []
         self.target_files = set(target_files) if target_files else None
 
         # Extract directory excludes from exclude_patterns
         self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
         for pattern in self.exclude_patterns:
@@ -267,7 +281,13 @@ class FileSearchTool(AgentTool):
                 self.exclude_dirs.add(pattern[:-3])
             elif "/" not in pattern and "*" not in pattern:
                 self.exclude_dirs.add(pattern)
 
+    @staticmethod
+    def _read_file_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines of a file (for asyncio.to_thread)."""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "search_code"
@@ -360,11 +380,13 @@ class FileSearchTool(AgentTool):
                 continue
 
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    lines = f.readlines()
+                # Read the file asynchronously to avoid blocking the event loop
+                lines = await asyncio.to_thread(
+                    self._read_file_lines_sync, file_path
+                )
 
                 files_searched += 1
 
                 for i, line in enumerate(lines):
                     if pattern.search(line):
                         # Grab the surrounding context
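A design note on the search hunks above: each candidate file is still awaited one at a time inside the loop, so reads never overlap even though they no longer block the loop. If search throughput ever matters, the reads could be fanned out with asyncio.gather behind a semaphore; a sketch under that assumption (read_many, candidate_files, and the bound of 16 are illustrative, not part of the commit):

```python
import asyncio
from typing import List

def _read_file_lines_sync(file_path: str) -> List[str]:
    # Same shape as the helper added in the diff above.
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        return f.readlines()

async def read_many(candidate_files: List[str], limit: int = 16) -> List[List[str]]:
    # Bound concurrency so a large repository doesn't flood the thread pool.
    sem = asyncio.Semaphore(limit)

    async def read_one(path: str) -> List[str]:
        async with sem:
            return await asyncio.to_thread(_read_file_lines_sync, path)

    return await asyncio.gather(*(read_one(p) for p in candidate_files))
```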
@@ -739,6 +739,20 @@ class CodeIndexer:
         self._needs_rebuild = False
         self._rebuild_reason = ""
 
+    @staticmethod
+    def _read_file_sync(file_path: str) -> str:
+        """
+        Synchronously read a file's contents (for wrapping with asyncio.to_thread)
+
+        Args:
+            file_path: path to the file
+
+        Returns:
+            The file contents
+        """
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.read()
+
     async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
         """
         Initialize the indexer and detect whether the index needs rebuilding
@@ -916,8 +930,10 @@ class CodeIndexer:
             try:
                 relative_path = os.path.relpath(file_path, directory)
 
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -932,8 +948,8 @@
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # Chunk the content
-                chunks = self.splitter.split_file(content, relative_path)
+                # Chunk asynchronously so Tree-sitter parsing doesn't block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # Attach file_hash to each chunk
                 for chunk in chunks:
@@ -1018,8 +1034,10 @@ class CodeIndexer:
         for relative_path in files_to_check:
             file_path = current_file_map[relative_path]
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
                 current_hash = hashlib.md5(content.encode()).hexdigest()
                 if current_hash != indexed_file_hashes.get(relative_path):
                     files_to_update.add(relative_path)
@@ -1055,8 +1073,10 @@ class CodeIndexer:
             is_update = relative_path in files_to_update
 
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -1075,8 +1095,8 @@
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # Chunk the content
-                chunks = self.splitter.split_file(content, relative_path)
+                # Chunk asynchronously so Tree-sitter parsing doesn't block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # Attach file_hash to each chunk
                 for chunk in chunks:
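In the change-detection hunks above, the read moves to a thread but the md5 of the content is still computed on the event loop; for files near the 500000-character cap that hash is itself noticeable CPU work. One option is to fold the hash into the same threaded call; a sketch with a hypothetical combined helper (_read_and_hash_sync and needs_reindex are not part of the commit):

```python
import asyncio
import hashlib

def _read_and_hash_sync(file_path: str) -> tuple:
    # Hypothetical helper: read and hash in one threaded call so neither
    # step runs on the event loop.
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        content = f.read()
    return content, hashlib.md5(content.encode()).hexdigest()

async def needs_reindex(file_path: str, indexed_hash: str) -> bool:
    content, current_hash = await asyncio.to_thread(_read_and_hash_sync, file_path)
    # True when the on-disk content no longer matches the indexed hash.
    return current_hash != indexed_hash
```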
@@ -4,6 +4,7 @@
 """
 
 import re
+import asyncio
 import hashlib
 import logging
 from typing import List, Dict, Any, Optional, Tuple, Set
@@ -230,21 +231,30 @@ class TreeSitterParser:
         return False
 
     def parse(self, code: str, language: str) -> Optional[Any]:
-        """Parse code and return its AST"""
+        """Parse code and return its AST (synchronous)"""
         if not self._ensure_initialized(language):
             return None
 
         parser = self._parsers.get(language)
         if not parser:
             return None
 
         try:
             tree = parser.parse(code.encode())
             return tree
         except Exception as e:
             logger.warning(f"Failed to parse code: {e}")
             return None
 
+    async def parse_async(self, code: str, language: str) -> Optional[Any]:
+        """
+        Asynchronously parse code and return its AST
+
+        Runs the CPU-heavy Tree-sitter parse in a worker thread so it
+        does not block the event loop
+        """
+        return await asyncio.to_thread(self.parse, code, language)
+
     def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
         """Extract definitions from an AST"""
         if tree is None:
@@ -449,9 +459,31 @@ class CodeSplitter:
         except Exception as e:
             logger.warning(f"Chunking failed for {file_path}: {e}, falling back to line-based chunking")
             chunks = self._split_by_lines(content, file_path, language)
 
         return chunks
 
+    async def split_file_async(
+        self,
+        content: str,
+        file_path: str,
+        language: Optional[str] = None
+    ) -> List[CodeChunk]:
+        """
+        Asynchronously split a single file
+
+        Runs the CPU-heavy chunking (including Tree-sitter parsing) in a
+        worker thread so it does not block the event loop.
+
+        Args:
+            content: file contents
+            file_path: file path
+            language: programming language (optional)
+
+        Returns:
+            List of code chunks
+        """
+        return await asyncio.to_thread(self.split_file, content, file_path, language)
+
     def _split_by_ast(
         self,
         content: str,
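Taken together, callers can now parse and chunk without stalling the loop; a minimal usage sketch (object construction and the sample source are illustrative):

```python
import asyncio

async def index_snippet(parser: "TreeSitterParser", splitter: "CodeSplitter") -> None:
    code = "def answer():\n    return 42\n"
    tree = await parser.parse_async(code, "python")              # AST, or None on failure
    chunks = await splitter.split_file_async(code, "answer.py")  # List[CodeChunk]
    print(tree is not None, len(chunks))
```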