From cd8fb49a5614ca9e9efc03244fe04126a9f4bf34 Mon Sep 17 00:00:00 2001
From: vinland100
Date: Fri, 9 Jan 2026 16:12:55 +0800
Subject: [PATCH] resolve indexing hangs and optimize code splitting
 performance

---
 backend/app/services/rag/splitter.py | 102 +++++++++++++++++++++------
 1 file changed, 80 insertions(+), 22 deletions(-)

diff --git a/backend/app/services/rag/splitter.py b/backend/app/services/rag/splitter.py
index cfc6002..57b909b 100644
--- a/backend/app/services/rag/splitter.py
+++ b/backend/app/services/rag/splitter.py
@@ -7,6 +7,8 @@ import re
 import asyncio
 import hashlib
 import logging
+import threading
+import time
 from typing import List, Dict, Any, Optional, Tuple, Set
 from dataclasses import dataclass, field
 from enum import Enum
@@ -14,6 +16,9 @@ from pathlib import Path
 
 logger = logging.getLogger(__name__)
 
+# 全局缓存 tiktoken 编码器,避免在 CodeChunk 初始化时重复加载
+_TIKTOKEN_ENCODING = None
+
 
 class ChunkType(Enum):
     """代码块类型"""
@@ -85,12 +90,18 @@
 
     def _estimate_tokens(self) -> int:
         # 使用 tiktoken 如果可用
+        global _TIKTOKEN_ENCODING
         try:
-            import tiktoken
-            enc = tiktoken.get_encoding("cl100k_base")
-            return len(enc.encode(self.content))
+            if _TIKTOKEN_ENCODING is None:
+                import tiktoken
+                _TIKTOKEN_ENCODING = tiktoken.get_encoding("cl100k_base")
+            return len(_TIKTOKEN_ENCODING.encode(self.content))
         except ImportError:
             return len(self.content) // 4
+        except Exception as e:
+            # 避免编码过程中的异常导致整个任务失败
+            logger.debug(f"Token estimation error: {e}")
+            return len(self.content) // 4
 
     def to_dict(self) -> Dict[str, Any]:
         result = {
@@ -254,12 +265,21 @@
     }
 
     def __init__(self):
-        self._parsers: Dict[str, Any] = {}
+        # 使用线程本地存储来保存解析器,因为 Tree-sitter Parser 对象不是线程安全的
+        # asyncio.to_thread 会在不同的线程中执行,如果共享同一个 parser 会导致崩溃或挂起
+        self._local = threading.local()
         self._initialized = False
 
+    def _get_parsers(self) -> Dict[str, Any]:
+        """获取当前线程的解析器字典"""
+        if not hasattr(self._local, 'parsers'):
+            self._local.parsers = {}
+        return self._local.parsers
+
     def _ensure_initialized(self, language: str) -> bool:
         """确保语言解析器已初始化"""
-        if language in self._parsers:
+        parsers = self._get_parsers()
+        if language in parsers:
             return True
 
         # 检查语言是否受支持
@@ -271,7 +291,7 @@
             from tree_sitter_language_pack import get_parser
 
             parser = get_parser(language)
-            self._parsers[language] = parser
+            parsers[language] = parser
             return True
 
         except ImportError:
@@ -286,7 +306,8 @@
         if not self._ensure_initialized(language):
             return None
 
-        parser = self._parsers.get(language)
+        parsers = self._get_parsers()
+        parser = parsers.get(language)
         if not parser:
             return None
 
@@ -497,7 +518,11 @@
     def detect_language(self, file_path: str) -> str:
         """检测编程语言"""
         ext = Path(file_path).suffix.lower()
-        return TreeSitterParser.LANGUAGE_MAP.get(ext, "text")
+        lang = TreeSitterParser.LANGUAGE_MAP.get(ext, "text")
+        # 针对 .cs 文件,确保映射到 csharp
+        if ext == ".cs":
+            return "csharp"
+        return lang
 
     def split_file(
         self,
@@ -507,21 +532,19 @@
     ) -> List[CodeChunk]:
         """
         分割单个文件
-
-        Args:
-            content: 文件内容
-            file_path: 文件路径
-            language: 编程语言(可选)
-
-        Returns:
-            代码块列表
         """
         if not content or not content.strip():
             return []
 
         if language is None:
             language = self.detect_language(file_path)
+
+        # 记录开始处理大型文件
+        file_size_kb = len(content) / 1024
+        if file_size_kb > 100:
+            logger.debug(f"Starting to split large file: {file_path} ({file_size_kb:.1f} KB, lang: {language})")
+        start_time = time.time()
 
         chunks = []
 
         try:
@@ -561,6 +584,10 @@
             # 后处理:使用语义分析增强
             self._enrich_chunks_with_semantics(chunks, content, language)
 
+            if file_size_kb > 100:
+                elapsed = time.time() - start_time
+                logger.debug(f"Finished splitting {file_path}: {len(chunks)} chunks, took {elapsed:.2f}s")
+
         except Exception as e:
             logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
             chunks = self._split_by_lines(content, file_path, language)
@@ -939,12 +966,40 @@
     def _filter_relevant_imports(self, all_imports: List[str],
                                  chunk_content: str) -> List[str]:
         """过滤与代码块相关的导入"""
-        relevant = []
+        if not all_imports:
+            return []
+
+        # 提取导入的末尾部分(通常是类名或模块名)
+        names_map = {}
         for imp in all_imports:
-            module_name = imp.split('.')[-1]
-            if re.search(rf'\b{re.escape(module_name)}\b', chunk_content):
-                relevant.append(imp)
-        return relevant[:20]
+            name = imp.split('.')[-1]
+            if name:
+                if name not in names_map:
+                    names_map[name] = []
+                names_map[name].append(imp)
+
+        if not names_map:
+            return []
+
+        # 构建一个综合正则,一次性匹配所有导入名称,效率远高于循环 re.search
+        try:
+            # 限制名称数量,避免正则过大
+            target_names = list(names_map.keys())[:200]
+            pattern = rf"\b({'|'.join(re.escape(name) for name in target_names)})\b"
+            matches = set(re.findall(pattern, chunk_content))
+
+            relevant = []
+            for m in matches:
+                relevant.extend(names_map[m])
+            return relevant[:20]
+        except Exception:
+            # 回退到简单搜索
+            relevant = []
+            for name, imps in names_map.items():
+                if name in chunk_content:
+                    relevant.extend(imps)
+                    if len(relevant) >= 20: break
+            return relevant[:20]
 
     def _extract_function_calls(self, content: str, language: str) -> List[str]:
         """提取函数调用"""
@@ -979,7 +1034,10 @@
             ],
             "csharp": [
                 r"(?:class|record|struct|enum|interface)\s+(\w+)",
-                r"[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)",
+                # 备用
+                # r"[^\s<>\[\],]+\s+(\w+)\s*\(",
+                # 究极版:支持基础类型、泛型、可空类型 (?)、数组 ([,] [][])、元组返回类型 ((int, string)) 以及构造函数
+                r"(?:(?:(?:[a-zA-Z_][\w.]*(?:<[\w\s,<>]+>)?\??|(?:\([\w\s,<>?]+\)))(?:\[[\s,]*\])*\s+)+)?(\w+)\s*(?:<[\w\s,<>]+>)?\s*\(",
            ],
             "cpp": [
                 r"(?:class|struct)\s+(\w+)",