""" 代码检索器 支持语义检索和混合检索 """ import re import asyncio import logging from typing import List, Dict, Any, Optional from dataclasses import dataclass, field from .embeddings import EmbeddingService from .indexer import VectorStore, ChromaVectorStore, InMemoryVectorStore from .splitter import CodeChunk, ChunkType from app.core.config import settings logger = logging.getLogger(__name__) @dataclass class RetrievalResult: """检索结果""" chunk_id: str content: str file_path: str language: str chunk_type: str line_start: int line_end: int score: float # 相似度分数 (0-1, 越高越相似) # 可选的元数据 name: Optional[str] = None parent_name: Optional[str] = None signature: Optional[str] = None security_indicators: List[str] = field(default_factory=list) # 原始元数据 metadata: Dict[str, Any] = field(default_factory=dict) def to_dict(self) -> Dict[str, Any]: return { "chunk_id": self.chunk_id, "content": self.content, "file_path": self.file_path, "language": self.language, "chunk_type": self.chunk_type, "line_start": self.line_start, "line_end": self.line_end, "score": self.score, "name": self.name, "parent_name": self.parent_name, "signature": self.signature, "security_indicators": self.security_indicators, } def to_context_string(self, include_metadata: bool = True) -> str: """转换为上下文字符串(用于 LLM 输入)""" parts = [] if include_metadata: header = f"File: {self.file_path}" if self.line_start and self.line_end: header += f" (lines {self.line_start}-{self.line_end})" if self.name: header += f"\n{self.chunk_type.title()}: {self.name}" if self.parent_name: header += f" in {self.parent_name}" parts.append(header) parts.append(f"```{self.language}\n{self.content}\n```") return "\n".join(parts) class CodeRetriever: """ 代码检索器 支持语义检索、关键字检索和混合检索 🔥 自动兼容不同维度的向量: - 查询时自动检测 collection 的 embedding 配置 - 动态创建对应的 embedding 服务 """ def __init__( self, collection_name: str, embedding_service: Optional[EmbeddingService] = None, vector_store: Optional[VectorStore] = None, persist_directory: Optional[str] = None, api_key: Optional[str] = None, # 🔥 新增:用于动态创建 embedding 服务 ): """ 初始化检索器 Args: collection_name: 向量集合名称 embedding_service: 嵌入服务(可选,会根据 collection 配置自动创建) vector_store: 向量存储 persist_directory: 持久化目录 api_key: API Key(用于动态创建 embedding 服务) """ self.collection_name = collection_name self._provided_embedding_service = embedding_service # 用户提供的 embedding 服务 self.embedding_service = embedding_service or EmbeddingService() # 实际使用的 embedding 服务 self._api_key = api_key or (getattr(self.embedding_service, 'api_key', None) if self.embedding_service else None) # 创建向量存储 if vector_store: self.vector_store = vector_store else: try: self.vector_store = ChromaVectorStore( collection_name=collection_name, persist_directory=persist_directory, ) except ImportError: logger.warning("Chroma not available, using in-memory store") self.vector_store = InMemoryVectorStore(collection_name=collection_name) self._initialized = False async def initialize(self): """初始化检索器,自动检测并适配 collection 的 embedding 配置""" if self._initialized: return await self.vector_store.initialize() # 🔥 自动检测 collection 的 embedding 配置 if hasattr(self.vector_store, 'get_embedding_config'): stored_config = self.vector_store.get_embedding_config() stored_provider = stored_config.get("provider") stored_model = stored_config.get("model") stored_dimension = stored_config.get("dimension") stored_base_url = stored_config.get("base_url") # 🔥 如果没有存储的配置(旧的 collection),尝试通过维度推断 if not stored_provider or not stored_model: inferred = await self._infer_embedding_config_from_dimension() if inferred: stored_provider = inferred.get("provider") stored_model = inferred.get("model") stored_dimension = inferred.get("dimension") logger.info(f"📊 从向量维度推断 embedding 配置: {stored_provider}/{stored_model}") if stored_provider and stored_model: # 检查是否需要使用不同的 embedding 服务 current_provider = getattr(self.embedding_service, 'provider', None) if self.embedding_service else None current_model = getattr(self.embedding_service, 'model', None) if self.embedding_service else None if current_provider != stored_provider or current_model != stored_model: logger.info( f"🔄 Collection 使用的 embedding 配置与当前不同: " f"{stored_provider}/{stored_model} (维度: {stored_dimension}) vs " f"{current_provider}/{current_model}" ) logger.info(f"🔄 自动切换到 collection 的 embedding 配置") # 动态创建对应的 embedding 服务 api_key = self._api_key if not api_key and self._provided_embedding_service: api_key = getattr(self._provided_embedding_service, 'api_key', None) # 🔥 重要:如果用户显式指定了维度,且与存储的维度不匹配,则不应自动切换(会导致报错) explicit_dim = getattr(settings, "EMBEDDING_DIMENSION", 0) if explicit_dim > 0 and explicit_dim != stored_dimension: logger.warning(f"⚠️ Collection 维度 ({stored_dimension}) 与显式指定的维度 ({explicit_dim}) 不匹配,跳过自动切换以避免错误。") return self.embedding_service = EmbeddingService( provider=stored_provider, model=stored_model, api_key=api_key, base_url=stored_base_url, ) logger.info(f"✅ 已切换到: {stored_provider}/{stored_model}") # 如果仍然没有 embedding 服务,创建默认的 if not self.embedding_service: self.embedding_service = EmbeddingService() self._initialized = True async def _infer_embedding_config_from_dimension(self) -> Optional[Dict[str, Any]]: """ 🔥 从向量维度推断 embedding 配置(用于处理旧的 collection) Returns: 推断的 embedding 配置,如果无法推断则返回 None """ try: # 获取一个样本向量来检查维度 if hasattr(self.vector_store, '_collection') and self.vector_store._collection: count = await self.vector_store.get_count() if count > 0: sample = await asyncio.to_thread( self.vector_store._collection.peek, limit=1 ) embeddings = sample.get("embeddings") if embeddings is not None and len(embeddings) > 0: dim = len(embeddings[0]) # 🔥 2. Fallback to hardcoded mapping dimension_mapping = { # OpenAI 系列 1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536}, 3072: {"provider": "openai", "model": "text-embedding-3-large", "dimension": 3072}, # Qwen (DashScope) 系列 # 1536: {"provider": "qwen", "model": "text-embedding-v2", "dimension": 1536}, 1024: {"provider": "qwen", "model": "text-embedding-v4", "dimension": 1024}, # HuggingFace 系列 # 1024 已被 Qwen 占用,优先选用 Qwen 如果配置了 # 1024: {"provider": "huggingface", "model": "BAAI/bge-m3", "dimension": 1024}, 384: {"provider": "huggingface", "model": "sentence-transformers/all-MiniLM-L6-v2", "dimension": 384}, # Ollama 系列 768: {"provider": "ollama", "model": "nomic-embed-text", "dimension": 768}, # Jina 系列 512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512}, } # Special case: If EMBEDDING_PROVIDER is 'qwen', use qwen for 1536 too if settings.EMBEDDING_PROVIDER == "qwen" and dim == 1536: return {"provider": "qwen", "model": "text-embedding-v2", "dimension": 1536} inferred = dimension_mapping.get(dim) if inferred: logger.info(f"📊 检测到向量维度 {dim},推断为: {inferred['provider']}/{inferred['model']}") return inferred except Exception as e: logger.warning(f"无法推断 embedding 配置: {e}") return None def get_collection_embedding_config(self) -> Dict[str, Any]: """ 获取 collection 存储的 embedding 配置 Returns: 包含 provider, model, dimension, base_url 的字典 """ if hasattr(self.vector_store, 'get_embedding_config'): return self.vector_store.get_embedding_config() return {} async def retrieve( self, query: str, top_k: int = 10, filter_file_path: Optional[str] = None, filter_language: Optional[str] = None, filter_chunk_type: Optional[str] = None, min_score: float = 0.0, ) -> List[RetrievalResult]: """ 语义检索 Args: query: 查询文本 top_k: 返回数量 filter_file_path: 文件路径过滤 filter_language: 语言过滤 filter_chunk_type: 块类型过滤 min_score: 最小相似度分数 Returns: 检索结果列表 """ await self.initialize() # 生成查询嵌入 query_embedding = await self.embedding_service.embed(query) # 构建过滤条件 where = {} if filter_file_path: where["file_path"] = filter_file_path if filter_language: where["language"] = filter_language if filter_chunk_type: where["chunk_type"] = filter_chunk_type # 查询向量存储 raw_results = await self.vector_store.query( query_embedding=query_embedding, n_results=top_k * 2, # 多查一些,后面过滤 where=where if where else None, ) # 转换结果 results = [] for i, (id_, doc, meta, dist) in enumerate(zip( raw_results["ids"], raw_results["documents"], raw_results["metadatas"], raw_results["distances"], )): # 将距离转换为相似度分数 (余弦距离) score = 1 - dist if score < min_score: continue # 解析安全指标(可能是 JSON 字符串) security_indicators = meta.get("security_indicators", []) if isinstance(security_indicators, str): try: import json security_indicators = json.loads(security_indicators) except: security_indicators = [] result = RetrievalResult( chunk_id=id_, content=doc, file_path=meta.get("file_path", ""), language=meta.get("language", "text"), chunk_type=meta.get("chunk_type", "unknown"), line_start=meta.get("line_start", 0), line_end=meta.get("line_end", 0), score=score, name=meta.get("name"), parent_name=meta.get("parent_name"), signature=meta.get("signature"), security_indicators=security_indicators, metadata=meta, ) results.append(result) # 按分数排序并截取 results.sort(key=lambda x: x.score, reverse=True) return results[:top_k] async def retrieve_by_file( self, file_path: str, top_k: int = 50, ) -> List[RetrievalResult]: """ 按文件路径检索 Args: file_path: 文件路径 top_k: 返回数量 Returns: 该文件的所有代码块 """ await self.initialize() # 使用一个通用查询 query_embedding = await self.embedding_service.embed(f"code in {file_path}") raw_results = await self.vector_store.query( query_embedding=query_embedding, n_results=top_k, where={"file_path": file_path}, ) results = [] for id_, doc, meta, dist in zip( raw_results["ids"], raw_results["documents"], raw_results["metadatas"], raw_results["distances"], ): result = RetrievalResult( chunk_id=id_, content=doc, file_path=meta.get("file_path", ""), language=meta.get("language", "text"), chunk_type=meta.get("chunk_type", "unknown"), line_start=meta.get("line_start", 0), line_end=meta.get("line_end", 0), score=1 - dist, name=meta.get("name"), parent_name=meta.get("parent_name"), metadata=meta, ) results.append(result) # 按行号排序 results.sort(key=lambda x: x.line_start) return results async def retrieve_security_related( self, vulnerability_type: Optional[str] = None, top_k: int = 20, ) -> List[RetrievalResult]: """ 检索与安全相关的代码 Args: vulnerability_type: 漏洞类型(如 sql_injection, xss 等) top_k: 返回数量 Returns: 安全相关的代码块 """ # 根据漏洞类型构建查询 security_queries = { "sql_injection": "SQL query execute database user input", "xss": "HTML render user input innerHTML template", "command_injection": "system exec command shell subprocess", "path_traversal": "file path read open user input", "ssrf": "HTTP request URL user input fetch", "deserialization": "deserialize pickle yaml load object", "auth_bypass": "authentication login password token session", "hardcoded_secret": "password secret key token credential", } if vulnerability_type and vulnerability_type in security_queries: query = security_queries[vulnerability_type] else: query = "security vulnerability dangerous function user input" return await self.retrieve(query, top_k=top_k) async def retrieve_function_context( self, function_name: str, file_path: Optional[str] = None, include_callers: bool = True, include_callees: bool = True, top_k: int = 10, ) -> Dict[str, List[RetrievalResult]]: """ 检索函数上下文 Args: function_name: 函数名 file_path: 文件路径(可选) include_callers: 是否包含调用者 include_callees: 是否包含被调用者 top_k: 每类返回数量 Returns: 包含函数定义、调用者、被调用者的字典 """ context = { "definition": [], "callers": [], "callees": [], } # 查找函数定义 definition_query = f"function definition {function_name}" definitions = await self.retrieve( definition_query, top_k=5, filter_file_path=file_path, ) # 过滤出真正的定义 for result in definitions: if result.name == function_name or function_name in (result.content or ""): context["definition"].append(result) if include_callers: # 查找调用此函数的代码 caller_query = f"calls {function_name} invoke {function_name}" callers = await self.retrieve(caller_query, top_k=top_k) for result in callers: # 检查是否真的调用了这个函数 if re.search(rf'\b{re.escape(function_name)}\s*\(', result.content): if result not in context["definition"]: context["callers"].append(result) if include_callees and context["definition"]: # 从函数定义中提取调用的其他函数 for definition in context["definition"]: calls = re.findall(r'\b(\w+)\s*\(', definition.content) unique_calls = list(set(calls))[:5] # 限制数量 for call in unique_calls: if call == function_name: continue callees = await self.retrieve( f"function {call} definition", top_k=2, ) context["callees"].extend(callees) return context async def retrieve_similar_code( self, code_snippet: str, top_k: int = 5, exclude_file: Optional[str] = None, ) -> List[RetrievalResult]: """ 检索相似的代码 Args: code_snippet: 代码片段 top_k: 返回数量 exclude_file: 排除的文件 Returns: 相似代码列表 """ results = await self.retrieve( f"similar code: {code_snippet}", top_k=top_k * 2, ) if exclude_file: results = [r for r in results if r.file_path != exclude_file] return results[:top_k] async def hybrid_retrieve( self, query: str, keywords: Optional[List[str]] = None, top_k: int = 10, semantic_weight: float = 0.7, ) -> List[RetrievalResult]: """ 混合检索(语义 + 关键字) Args: query: 查询文本 keywords: 额外的关键字 top_k: 返回数量 semantic_weight: 语义检索权重 Returns: 检索结果列表 """ # 语义检索 semantic_results = await self.retrieve(query, top_k=top_k * 2) # 如果有关键字,进行关键字过滤/增强 if keywords: keyword_pattern = '|'.join(re.escape(kw) for kw in keywords) enhanced_results = [] for result in semantic_results: # 计算关键字匹配度 matches = len(re.findall(keyword_pattern, result.content, re.IGNORECASE)) keyword_score = min(1.0, matches / len(keywords)) # 混合分数 hybrid_score = ( semantic_weight * result.score + (1 - semantic_weight) * keyword_score ) result.score = hybrid_score enhanced_results.append(result) enhanced_results.sort(key=lambda x: x.score, reverse=True) return enhanced_results[:top_k] return semantic_results[:top_k] def format_results_for_llm( self, results: List[RetrievalResult], max_tokens: int = 4000, include_metadata: bool = True, ) -> str: """ 将检索结果格式化为 LLM 输入 Args: results: 检索结果 max_tokens: 最大 Token 数 include_metadata: 是否包含元数据 Returns: 格式化的字符串 """ if not results: return "No relevant code found." parts = [] total_tokens = 0 for i, result in enumerate(results): context = result.to_context_string(include_metadata=include_metadata) estimated_tokens = len(context) // 4 if total_tokens + estimated_tokens > max_tokens: break parts.append(f"### Code Block {i + 1} (Score: {result.score:.2f})\n{context}") total_tokens += estimated_tokens return "\n\n".join(parts)