feat: add async support for file operations and code parsing

Convert synchronous file reading, code parsing, and chunking to async implementations, using asyncio.to_thread to run CPU-bound work in a thread pool so it does not block the event loop. Main changes:

- Add async parsing methods to TreeSitterParser and CodeSplitter
- Switch CodeIndexer to async file reading and chunking
- Add async file reading support to FileReadTool and FileSearchTool
parent fdbec80da5
commit c7632afdab
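The change follows one pattern throughout: each blocking call keeps a small synchronous helper, and the async caller hands that helper to asyncio.to_thread (available since Python 3.9). A minimal sketch of the pattern with hypothetical names — the real helpers live on FileReadTool, FileSearchTool, CodeIndexer, and CodeSplitter in the diff below:

import asyncio

def _read_file_sync(file_path: str) -> str:
    # Blocking I/O; safe to run in a worker thread
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        return f.read()

async def read_file(file_path: str) -> str:
    # asyncio.to_thread submits the blocking helper to the default
    # thread pool and suspends this coroutine until it finishes,
    # so the event loop stays free to serve other tasks
    return await asyncio.to_thread(_read_file_sync, file_path)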
@@ -6,6 +6,7 @@
 import os
 import re
 import fnmatch
+import asyncio
 from typing import Optional, List, Dict, Any
 from pydantic import BaseModel, Field
 
@@ -44,7 +45,37 @@ class FileReadTool(AgentTool):
         self.project_root = project_root
         self.exclude_patterns = exclude_patterns or []
         self.target_files = set(target_files) if target_files else None
 
+    @staticmethod
+    def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
+        """Synchronously read a range of lines from a file (for asyncio.to_thread)."""
+        selected_lines = []
+        total_lines = 0
+        file_size = os.path.getsize(file_path)
+
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            for i, line in enumerate(f):
+                total_lines = i + 1
+                if i >= start_idx and i < end_idx:
+                    selected_lines.append(line)
+                elif i >= end_idx:
+                    if i < end_idx + 1000:
+                        continue
+                    else:
+                        remaining_bytes = file_size - f.tell()
+                        avg_line_size = f.tell() / (i + 1)
+                        estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
+                        total_lines = i + 1 + estimated_remaining_lines
+                        break
+
+        return selected_lines, total_lines
+
+    @staticmethod
+    def _read_all_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines from a file (for asyncio.to_thread)."""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "read_file"
@@ -136,51 +167,34 @@ class FileReadTool(AgentTool):
 
         # 🔥 For large files, stream-read only the requested line range
         if is_large_file and (start_line is not None or end_line is not None):
-            # Stream the read so the whole file is never loaded at once
-            selected_lines = []
-            total_lines = 0
 
             # Compute the effective start and end lines
             start_idx = max(0, (start_line or 1) - 1)
             end_idx = end_line if end_line else start_idx + max_lines
 
-            with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                for i, line in enumerate(f):
-                    total_lines = i + 1
-                    if i >= start_idx and i < end_idx:
-                        selected_lines.append(line)
-                    elif i >= end_idx:
-                        # Keep counting to get the total, but cap how much is read
-                        if i < end_idx + 1000:  # read at most 1000 extra lines to estimate the total
-                            continue
-                        else:
-                            # Estimate the number of remaining lines
-                            remaining_bytes = file_size - f.tell()
-                            avg_line_size = f.tell() / (i + 1)
-                            estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
-                            total_lines = i + 1 + estimated_remaining_lines
-                            break
-
+            # Read the file asynchronously to avoid blocking the event loop
+            selected_lines, total_lines = await asyncio.to_thread(
+                self._read_file_lines_sync, full_path, start_idx, end_idx
+            )
 
             # Update the effective end index
             end_idx = min(end_idx, start_idx + len(selected_lines))
         else:
-            # Plain read for small files
-            with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                lines = f.readlines()
+            # Read small files asynchronously to avoid blocking the event loop
+            lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)
 
             total_lines = len(lines)
 
             # Resolve the line range
             if start_line is not None:
                 start_idx = max(0, start_line - 1)
             else:
                 start_idx = 0
 
             if end_line is not None:
                 end_idx = min(total_lines, end_line)
             else:
                 end_idx = min(total_lines, start_idx + max_lines)
 
             # Slice out the requested lines
             selected_lines = lines[start_idx:end_idx]
@@ -259,7 +273,7 @@ class FileSearchTool(AgentTool):
         self.project_root = project_root
         self.exclude_patterns = exclude_patterns or []
         self.target_files = set(target_files) if target_files else None
 
         # Derive directory exclusions from exclude_patterns
         self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
         for pattern in self.exclude_patterns:
@@ -267,7 +281,13 @@ class FileSearchTool(AgentTool):
                 self.exclude_dirs.add(pattern[:-3])
             elif "/" not in pattern and "*" not in pattern:
                 self.exclude_dirs.add(pattern)
 
+    @staticmethod
+    def _read_file_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines from a file (for asyncio.to_thread)."""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "search_code"
@@ -360,11 +380,13 @@ class FileSearchTool(AgentTool):
                     continue
 
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    lines = f.readlines()
+                # Read the file asynchronously to avoid blocking the event loop
+                lines = await asyncio.to_thread(
+                    self._read_file_lines_sync, file_path
+                )
 
                 files_searched += 1
 
                 for i, line in enumerate(lines):
                     if pattern.search(line):
                         # Capture surrounding context
 
@@ -739,6 +739,20 @@ class CodeIndexer:
         self._needs_rebuild = False
         self._rebuild_reason = ""
 
+    @staticmethod
+    def _read_file_sync(file_path: str) -> str:
+        """
+        Synchronously read a file's contents (wrapped by asyncio.to_thread)
+
+        Args:
+            file_path: Path to the file
+
+        Returns:
+            The file contents
+        """
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.read()
+
     async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
         """
         Initialize the indexer and detect whether the index needs rebuilding
@@ -916,8 +930,10 @@ class CodeIndexer:
             try:
                 relative_path = os.path.relpath(file_path, directory)
 
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -932,8 +948,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # Chunk the file
-                chunks = self.splitter.split_file(content, relative_path)
+                # Chunk asynchronously so Tree-sitter parsing doesn't block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # Attach file_hash to each chunk
                 for chunk in chunks:
@@ -1018,8 +1034,10 @@ class CodeIndexer:
         for relative_path in files_to_check:
             file_path = current_file_map[relative_path]
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
                 current_hash = hashlib.md5(content.encode()).hexdigest()
                 if current_hash != indexed_file_hashes.get(relative_path):
                     files_to_update.add(relative_path)
@@ -1055,8 +1073,10 @@ class CodeIndexer:
             is_update = relative_path in files_to_update
 
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously to avoid blocking the event loop
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -1075,8 +1095,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # Chunk the file
-                chunks = self.splitter.split_file(content, relative_path)
+                # Chunk asynchronously so Tree-sitter parsing doesn't block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # Attach file_hash to each chunk
                 for chunk in chunks:
 
@@ -4,6 +4,7 @@
 """
 
 import re
+import asyncio
 import hashlib
 import logging
 from typing import List, Dict, Any, Optional, Tuple, Set
@@ -230,21 +231,30 @@ class TreeSitterParser:
         return False
 
     def parse(self, code: str, language: str) -> Optional[Any]:
-        """Parse code and return its AST"""
+        """Parse code and return its AST (synchronous)"""
         if not self._ensure_initialized(language):
             return None
 
         parser = self._parsers.get(language)
         if not parser:
             return None
 
         try:
             tree = parser.parse(code.encode())
             return tree
         except Exception as e:
             logger.warning(f"Failed to parse code: {e}")
             return None
 
+    async def parse_async(self, code: str, language: str) -> Optional[Any]:
+        """
+        Asynchronously parse code and return its AST
+
+        Runs the CPU-bound Tree-sitter parse in a worker thread so it
+        does not block the event loop
+        """
+        return await asyncio.to_thread(self.parse, code, language)
+
     def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
         """Extract definitions from the AST"""
         if tree is None:
@@ -449,9 +459,31 @@ class CodeSplitter:
         except Exception as e:
             logger.warning(f"Chunking failed {file_path}: {e}, using simple line-based chunking")
             chunks = self._split_by_lines(content, file_path, language)
 
         return chunks
 
+    async def split_file_async(
+        self,
+        content: str,
+        file_path: str,
+        language: Optional[str] = None
+    ) -> List[CodeChunk]:
+        """
+        Asynchronously split a single file
+
+        Runs the CPU-bound chunking (including Tree-sitter parsing) in a
+        worker thread so it does not block the event loop.
+
+        Args:
+            content: File contents
+            file_path: File path
+            language: Programming language (optional)
+
+        Returns:
+            List of code chunks
+        """
+        return await asyncio.to_thread(self.split_file, content, file_path, language)
+
     def _split_by_ast(
         self,
         content: str,
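For reference, a sketch of how a caller might drive the new async API. split_file_async is the method added in this diff; the surrounding function and its arguments are assumptions for illustration only:

import asyncio

async def index_one_file(splitter, path: str) -> int:
    # Hypothetical caller: read then chunk a file without blocking the event loop
    def _read() -> str:
        with open(path, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()

    content = await asyncio.to_thread(_read)
    chunks = await splitter.split_file_async(content, path)
    return len(chunks)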