From 17889dceeed8b752cae382138b7ebcefbb40a1c6 Mon Sep 17 00:00:00 2001 From: lintsinghua Date: Tue, 16 Dec 2025 15:28:03 +0800 Subject: [PATCH] =?UTF-8?q?feat(retriever):=20=E6=B7=BB=E5=8A=A0=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E9=80=82=E9=85=8D=E4=B8=8D=E5=90=8C=20embedding=20?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 支持自动检测 collection 的 embedding 配置并动态创建对应的 embedding 服务 新增从向量维度推断配置的功能,兼容旧的 collection --- backend/app/api/v1/endpoints/agent_tasks.py | 2 + backend/app/services/rag/embeddings.py | 22 +-- backend/app/services/rag/indexer.py | 61 +++++++-- backend/app/services/rag/retriever.py | 141 ++++++++++++++++++-- 4 files changed, 194 insertions(+), 32 deletions(-) diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index 420c4f6..9806a23 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -644,10 +644,12 @@ async def _initialize_tools( collection_name = f"project_{project_id}" if project_id else "default_project" # 创建 CodeRetriever(用于搜索) + # 🔥 传递 api_key,用于自动适配 collection 的 embedding 配置 retriever = CodeRetriever( collection_name=collection_name, embedding_service=embedding_service, persist_directory=settings.VECTOR_DB_PATH, + api_key=embedding_api_key, # 🔥 传递 api_key 以支持自动切换 embedding ) logger.info(f"✅ RAG 系统初始化成功: collection={collection_name}") diff --git a/backend/app/services/rag/embeddings.py b/backend/app/services/rag/embeddings.py index a2ec119..92008e2 100644 --- a/backend/app/services/rag/embeddings.py +++ b/backend/app/services/rag/embeddings.py @@ -487,7 +487,7 @@ class EmbeddingService: ): """ 初始化嵌入服务 - + Args: provider: 提供商 (openai, azure, ollama, cohere, huggingface, jina) model: 模型名称 @@ -497,20 +497,22 @@ class EmbeddingService: """ self.cache_enabled = cache_enabled self._cache: Dict[str, List[float]] = {} - - # 确定提供商 - provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai') - model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') - + + # 确定提供商(保存原始值用于属性访问) + self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai') + self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') + self.api_key = api_key + self.base_url = base_url + # 创建提供商实例 self._provider = self._create_provider( - provider=provider, - model=model, + provider=self.provider, + model=self.model, api_key=api_key, base_url=base_url, ) - - logger.info(f"Embedding service initialized with {provider}/{model}") + + logger.info(f"Embedding service initialized with {self.provider}/{self.model}") def _create_provider( self, diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py index 45607fc..3cffa77 100644 --- a/backend/app/services/rag/indexer.py +++ b/backend/app/services/rag/indexer.py @@ -109,23 +109,25 @@ class VectorStore: class ChromaVectorStore(VectorStore): """Chroma 向量存储""" - + def __init__( self, collection_name: str, persist_directory: Optional[str] = None, + embedding_config: Optional[Dict[str, Any]] = None, # 🔥 新增:embedding 配置 ): self.collection_name = collection_name self.persist_directory = persist_directory + self.embedding_config = embedding_config or {} # 🔥 存储 embedding 配置 self._client = None self._collection = None - + async def initialize(self): """初始化 Chroma""" try: import chromadb from chromadb.config import Settings - + if self.persist_directory: self._client = chromadb.PersistentClient( path=self.persist_directory, @@ -135,16 +137,44 @@ class ChromaVectorStore(VectorStore): self._client = chromadb.Client( settings=Settings(anonymized_telemetry=False), ) - + + # 🔥 构建 collection 元数据,包含 embedding 配置 + collection_metadata = {"hnsw:space": "cosine"} + if self.embedding_config: + # 在元数据中记录 embedding 配置 + collection_metadata["embedding_provider"] = self.embedding_config.get("provider", "openai") + collection_metadata["embedding_model"] = self.embedding_config.get("model", "text-embedding-3-small") + collection_metadata["embedding_dimension"] = self.embedding_config.get("dimension", 1536) + if self.embedding_config.get("base_url"): + collection_metadata["embedding_base_url"] = self.embedding_config.get("base_url") + self._collection = self._client.get_or_create_collection( name=self.collection_name, - metadata={"hnsw:space": "cosine"}, + metadata=collection_metadata, ) - + logger.info(f"Chroma collection '{self.collection_name}' initialized") - + except ImportError: raise ImportError("chromadb is required. Install with: pip install chromadb") + + def get_embedding_config(self) -> Dict[str, Any]: + """ + 🔥 获取 collection 的 embedding 配置 + + Returns: + 包含 provider, model, dimension, base_url 的字典 + """ + if not self._collection: + return {} + + metadata = self._collection.metadata or {} + return { + "provider": metadata.get("embedding_provider"), + "model": metadata.get("embedding_model"), + "dimension": metadata.get("embedding_dimension"), + "base_url": metadata.get("embedding_base_url"), + } async def add_documents( self, @@ -312,7 +342,7 @@ class CodeIndexer: 代码索引器 将代码文件分块、嵌入并索引到向量数据库 """ - + def __init__( self, collection_name: str, @@ -323,7 +353,7 @@ class CodeIndexer: ): """ 初始化索引器 - + Args: collection_name: 向量集合名称 embedding_service: 嵌入服务 @@ -334,7 +364,15 @@ class CodeIndexer: self.collection_name = collection_name self.embedding_service = embedding_service or EmbeddingService() self.splitter = splitter or CodeSplitter() - + + # 🔥 从 embedding_service 获取配置,用于存储到 collection 元数据 + embedding_config = { + "provider": getattr(self.embedding_service, 'provider', 'openai'), + "model": getattr(self.embedding_service, 'model', 'text-embedding-3-small'), + "dimension": getattr(self.embedding_service, 'dimension', 1536), + "base_url": getattr(self.embedding_service, 'base_url', None), + } + # 创建向量存储 if vector_store: self.vector_store = vector_store @@ -343,11 +381,12 @@ class CodeIndexer: self.vector_store = ChromaVectorStore( collection_name=collection_name, persist_directory=persist_directory, + embedding_config=embedding_config, # 🔥 传递 embedding 配置 ) except ImportError: logger.warning("Chroma not available, using in-memory store") self.vector_store = InMemoryVectorStore(collection_name=collection_name) - + self._initialized = False async def initialize(self): diff --git a/backend/app/services/rag/retriever.py b/backend/app/services/rag/retriever.py index d285bc3..d942b11 100644 --- a/backend/app/services/rag/retriever.py +++ b/backend/app/services/rag/retriever.py @@ -4,6 +4,7 @@ """ import re +import asyncio import logging from typing import List, Dict, Any, Optional from dataclasses import dataclass, field @@ -75,27 +76,35 @@ class CodeRetriever: """ 代码检索器 支持语义检索、关键字检索和混合检索 + + 🔥 自动兼容不同维度的向量: + - 查询时自动检测 collection 的 embedding 配置 + - 动态创建对应的 embedding 服务 """ - + def __init__( self, collection_name: str, embedding_service: Optional[EmbeddingService] = None, vector_store: Optional[VectorStore] = None, persist_directory: Optional[str] = None, + api_key: Optional[str] = None, # 🔥 新增:用于动态创建 embedding 服务 ): """ 初始化检索器 - + Args: collection_name: 向量集合名称 - embedding_service: 嵌入服务 + embedding_service: 嵌入服务(可选,会根据 collection 配置自动创建) vector_store: 向量存储 persist_directory: 持久化目录 + api_key: API Key(用于动态创建 embedding 服务) """ self.collection_name = collection_name - self.embedding_service = embedding_service or EmbeddingService() - + self._provided_embedding_service = embedding_service # 用户提供的 embedding 服务 + self.embedding_service = embedding_service # 实际使用的 embedding 服务 + self._api_key = api_key + # 创建向量存储 if vector_store: self.vector_store = vector_store @@ -108,14 +117,124 @@ class CodeRetriever: except ImportError: logger.warning("Chroma not available, using in-memory store") self.vector_store = InMemoryVectorStore(collection_name=collection_name) - + self._initialized = False - + async def initialize(self): - """初始化检索器""" - if not self._initialized: - await self.vector_store.initialize() - self._initialized = True + """初始化检索器,自动检测并适配 collection 的 embedding 配置""" + if self._initialized: + return + + await self.vector_store.initialize() + + # 🔥 自动检测 collection 的 embedding 配置 + if hasattr(self.vector_store, 'get_embedding_config'): + stored_config = self.vector_store.get_embedding_config() + stored_provider = stored_config.get("provider") + stored_model = stored_config.get("model") + stored_dimension = stored_config.get("dimension") + stored_base_url = stored_config.get("base_url") + + # 🔥 如果没有存储的配置(旧的 collection),尝试通过维度推断 + if not stored_provider or not stored_model: + inferred = await self._infer_embedding_config_from_dimension() + if inferred: + stored_provider = inferred.get("provider") + stored_model = inferred.get("model") + stored_dimension = inferred.get("dimension") + logger.info(f"📊 从向量维度推断 embedding 配置: {stored_provider}/{stored_model}") + + if stored_provider and stored_model: + # 检查是否需要使用不同的 embedding 服务 + current_provider = getattr(self.embedding_service, 'provider', None) if self.embedding_service else None + current_model = getattr(self.embedding_service, 'model', None) if self.embedding_service else None + + if current_provider != stored_provider or current_model != stored_model: + logger.info( + f"🔄 Collection 使用的 embedding 配置与当前不同: " + f"{stored_provider}/{stored_model} (维度: {stored_dimension}) vs " + f"{current_provider}/{current_model}" + ) + logger.info(f"🔄 自动切换到 collection 的 embedding 配置") + + # 动态创建对应的 embedding 服务 + api_key = self._api_key + if not api_key and self._provided_embedding_service: + api_key = getattr(self._provided_embedding_service, 'api_key', None) + + self.embedding_service = EmbeddingService( + provider=stored_provider, + model=stored_model, + api_key=api_key, + base_url=stored_base_url, + ) + logger.info(f"✅ 已切换到: {stored_provider}/{stored_model}") + + # 如果仍然没有 embedding 服务,创建默认的 + if not self.embedding_service: + self.embedding_service = EmbeddingService() + + self._initialized = True + + async def _infer_embedding_config_from_dimension(self) -> Optional[Dict[str, Any]]: + """ + 🔥 从向量维度推断 embedding 配置(用于处理旧的 collection) + + Returns: + 推断的 embedding 配置,如果无法推断则返回 None + """ + try: + # 获取一个样本向量来检查维度 + if hasattr(self.vector_store, '_collection') and self.vector_store._collection: + count = await self.vector_store.get_count() + if count > 0: + sample = await asyncio.to_thread( + self.vector_store._collection.peek, + limit=1 + ) + embeddings = sample.get("embeddings") + if embeddings is not None and len(embeddings) > 0: + dim = len(embeddings[0]) + + # 🔥 根据维度推断模型(优先选择常用模型) + dimension_mapping = { + # OpenAI 系列 + 1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536}, + 3072: {"provider": "openai", "model": "text-embedding-3-large", "dimension": 3072}, + + # HuggingFace 系列 + 1024: {"provider": "huggingface", "model": "BAAI/bge-m3", "dimension": 1024}, + 384: {"provider": "huggingface", "model": "sentence-transformers/all-MiniLM-L6-v2", "dimension": 384}, + + # Ollama 系列 + 768: {"provider": "ollama", "model": "nomic-embed-text", "dimension": 768}, + + # Jina 系列 + 512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512}, + + # Cohere 系列 + # 1024 已被 HuggingFace 占用,Cohere 维度相同时会默认使用 HuggingFace + } + + inferred = dimension_mapping.get(dim) + if inferred: + logger.info(f"📊 检测到向量维度 {dim},推断为: {inferred['provider']}/{inferred['model']}") + return inferred + except Exception as e: + logger.warning(f"无法推断 embedding 配置: {e}") + + return None + + def get_collection_embedding_config(self) -> Dict[str, Any]: + """ + 获取 collection 存储的 embedding 配置 + + Returns: + 包含 provider, model, dimension, base_url 的字典 + """ + if hasattr(self.vector_store, 'get_embedding_config'): + return self.vector_store.get_embedding_config() + return {} async def retrieve( self,