resolve indexing hangs and optimize code splitting performance
Build and Push CodeReview / build (push) Waiting to run
Details
Build and Push CodeReview / build (push) Waiting to run
Details
This commit is contained in:
parent
f86e556d0d
commit
cd8fb49a56
|
|
@ -7,6 +7,8 @@ import re
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
from typing import List, Dict, Any, Optional, Tuple, Set
|
from typing import List, Dict, Any, Optional, Tuple, Set
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
@ -14,6 +16,9 @@ from pathlib import Path
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 全局缓存 tiktoken 编码器,避免在 CodeChunk 初始化时重复加载
|
||||||
|
_TIKTOKEN_ENCODING = None
|
||||||
|
|
||||||
|
|
||||||
class ChunkType(Enum):
|
class ChunkType(Enum):
|
||||||
"""代码块类型"""
|
"""代码块类型"""
|
||||||
|
|
@ -85,12 +90,18 @@ class CodeChunk:
|
||||||
|
|
||||||
def _estimate_tokens(self) -> int:
    """Estimate how many tokens ``self.content`` occupies.

    Uses the ``cl100k_base`` tiktoken encoding when tiktoken can be imported.
    The encoder is stored in the module-level ``_TIKTOKEN_ENCODING`` so it is
    loaded at most once instead of once per chunk.

    Returns:
        The encoded token count, or ``len(self.content) // 4`` when tiktoken
        is not installed or encoding fails for any reason.
    """
    global _TIKTOKEN_ENCODING
    try:
        if _TIKTOKEN_ENCODING is None:
            import tiktoken

            _TIKTOKEN_ENCODING = tiktoken.get_encoding("cl100k_base")
        # Source files may legitimately contain special-token text such as
        # "<|endoftext|>". With the default disallowed_special="all" that
        # raises ValueError; an empty collection encodes it as ordinary text
        # so the count stays accurate instead of falling back to the guess.
        return len(_TIKTOKEN_ENCODING.encode(self.content, disallowed_special=()))
    except ImportError:
        # tiktoken is optional: fall back to a rough 4-characters-per-token guess.
        return len(self.content) // 4
    except Exception as e:
        # An encoding failure must not abort the whole splitting task.
        logger.debug("Token estimation error: %s", e)
        return len(self.content) // 4
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
result = {
|
result = {
|
||||||
|
|
@ -254,12 +265,21 @@ class TreeSitterParser:
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._parsers: Dict[str, Any] = {}
|
# 使用线程本地存储来保存解析器,因为 Tree-sitter Parser 对象不是线程安全的
|
||||||
|
# asyncio.to_thread 会在不同的线程中执行,如果共享同一个 parser 会导致崩溃或挂起
|
||||||
|
self._local = threading.local()
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
|
def _get_parsers(self) -> Dict[str, Any]:
|
||||||
|
"""获取当前线程的解析器字典"""
|
||||||
|
if not hasattr(self._local, 'parsers'):
|
||||||
|
self._local.parsers = {}
|
||||||
|
return self._local.parsers
|
||||||
|
|
||||||
def _ensure_initialized(self, language: str) -> bool:
|
def _ensure_initialized(self, language: str) -> bool:
|
||||||
"""确保语言解析器已初始化"""
|
"""确保语言解析器已初始化"""
|
||||||
if language in self._parsers:
|
parsers = self._get_parsers()
|
||||||
|
if language in parsers:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 检查语言是否受支持
|
# 检查语言是否受支持
|
||||||
|
|
@ -271,7 +291,7 @@ class TreeSitterParser:
|
||||||
from tree_sitter_language_pack import get_parser
|
from tree_sitter_language_pack import get_parser
|
||||||
|
|
||||||
parser = get_parser(language)
|
parser = get_parser(language)
|
||||||
self._parsers[language] = parser
|
parsers[language] = parser
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
@ -286,7 +306,8 @@ class TreeSitterParser:
|
||||||
if not self._ensure_initialized(language):
|
if not self._ensure_initialized(language):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
parser = self._parsers.get(language)
|
parsers = self._get_parsers()
|
||||||
|
parser = parsers.get(language)
|
||||||
if not parser:
|
if not parser:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -497,7 +518,11 @@ class CodeSplitter:
|
||||||
def detect_language(self, file_path: str) -> str:
    """Detect the programming language from the file extension.

    The lower-cased suffix of ``file_path`` is looked up in
    ``TreeSitterParser.LANGUAGE_MAP``; unknown suffixes yield ``"text"``.
    A ``.cs`` suffix always yields ``"csharp"``, whatever the map contains.
    """
    suffix = Path(file_path).suffix.lower()
    mapped = TreeSitterParser.LANGUAGE_MAP.get(suffix, "text")
    # Force .cs files onto the csharp language regardless of the table entry.
    return "csharp" if suffix == ".cs" else mapped
|
||||||
|
|
||||||
def split_file(
|
def split_file(
|
||||||
self,
|
self,
|
||||||
|
|
@ -507,21 +532,19 @@ class CodeSplitter:
|
||||||
) -> List[CodeChunk]:
|
) -> List[CodeChunk]:
|
||||||
"""
|
"""
|
||||||
分割单个文件
|
分割单个文件
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 文件内容
|
|
||||||
file_path: 文件路径
|
|
||||||
language: 编程语言(可选)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
代码块列表
|
|
||||||
"""
|
"""
|
||||||
if not content or not content.strip():
|
if not content or not content.strip():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if language is None:
|
if language is None:
|
||||||
language = self.detect_language(file_path)
|
language = self.detect_language(file_path)
|
||||||
|
|
||||||
|
# 记录开始处理大型文件
|
||||||
|
file_size_kb = len(content) / 1024
|
||||||
|
if file_size_kb > 100:
|
||||||
|
logger.debug(f"Starting to split large file: {file_path} ({file_size_kb:.1f} KB, lang: {language})")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -561,6 +584,10 @@ class CodeSplitter:
|
||||||
# 后处理:使用语义分析增强
|
# 后处理:使用语义分析增强
|
||||||
self._enrich_chunks_with_semantics(chunks, content, language)
|
self._enrich_chunks_with_semantics(chunks, content, language)
|
||||||
|
|
||||||
|
if file_size_kb > 100:
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
logger.debug(f"Finished splitting {file_path}: {len(chunks)} chunks, took {elapsed:.2f}s")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
|
logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
|
||||||
chunks = self._split_by_lines(content, file_path, language)
|
chunks = self._split_by_lines(content, file_path, language)
|
||||||
|
|
@ -939,12 +966,40 @@ class CodeSplitter:
|
||||||
|
|
||||||
def _filter_relevant_imports(self, all_imports: List[str], chunk_content: str) -> List[str]:
|
def _filter_relevant_imports(self, all_imports: List[str], chunk_content: str) -> List[str]:
|
||||||
"""过滤与代码块相关的导入"""
|
"""过滤与代码块相关的导入"""
|
||||||
relevant = []
|
if not all_imports:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 提取导入的末尾部分(通常是类名或模块名)
|
||||||
|
names_map = {}
|
||||||
for imp in all_imports:
|
for imp in all_imports:
|
||||||
module_name = imp.split('.')[-1]
|
name = imp.split('.')[-1]
|
||||||
if re.search(rf'\b{re.escape(module_name)}\b', chunk_content):
|
if name:
|
||||||
relevant.append(imp)
|
if name not in names_map:
|
||||||
return relevant[:20]
|
names_map[name] = []
|
||||||
|
names_map[name].append(imp)
|
||||||
|
|
||||||
|
if not names_map:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 构建一个综合正则,一次性匹配所有导入名称,效率远高于循环 re.search
|
||||||
|
try:
|
||||||
|
# 限制名称数量,避免正则过大
|
||||||
|
target_names = list(names_map.keys())[:200]
|
||||||
|
pattern = rf"\b({'|'.join(re.escape(name) for name in target_names)})\b"
|
||||||
|
matches = set(re.findall(pattern, chunk_content))
|
||||||
|
|
||||||
|
relevant = []
|
||||||
|
for m in matches:
|
||||||
|
relevant.extend(names_map[m])
|
||||||
|
return relevant[:20]
|
||||||
|
except Exception:
|
||||||
|
# 回退到简单搜索
|
||||||
|
relevant = []
|
||||||
|
for name, imps in names_map.items():
|
||||||
|
if name in chunk_content:
|
||||||
|
relevant.extend(imps)
|
||||||
|
if len(relevant) >= 20: break
|
||||||
|
return relevant[:20]
|
||||||
|
|
||||||
def _extract_function_calls(self, content: str, language: str) -> List[str]:
|
def _extract_function_calls(self, content: str, language: str) -> List[str]:
|
||||||
"""提取函数调用"""
|
"""提取函数调用"""
|
||||||
|
|
@ -979,7 +1034,10 @@ class CodeSplitter:
|
||||||
],
|
],
|
||||||
"csharp": [
|
"csharp": [
|
||||||
r"(?:class|record|struct|enum|interface)\s+(\w+)",
|
r"(?:class|record|struct|enum|interface)\s+(\w+)",
|
||||||
r"[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)",
|
# 备用
|
||||||
|
# r"[^\s<>\[\],]+\s+(\w+)\s*\(",
|
||||||
|
# 究极版:支持基础类型、泛型、可空类型 (?)、数组 ([,] [][])、元组返回类型 ((int, string)) 以及构造函数
|
||||||
|
r"(?:(?:(?:[a-zA-Z_][\w.]*(?:<[\w\s,<>]+>)?\??|(?:\([\w\s,<>?]+\)))(?:\[[\s,]*\])*\s+)+)?(\w+)\s*(?:<[\w\s,<>]+>)?\s*\(",
|
||||||
],
|
],
|
||||||
"cpp": [
|
"cpp": [
|
||||||
r"(?:class|struct)\s+(\w+)",
|
r"(?:class|struct)\s+(\w+)",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue