feat(retriever): 添加自动适配不同 embedding 配置的功能

支持自动检测 collection 的 embedding 配置并动态创建对应的 embedding 服务
新增从向量维度推断配置的功能,兼容旧的 collection
This commit is contained in:
lintsinghua 2025-12-16 15:28:03 +08:00
parent 3176c35817
commit 17889dceee
4 changed files with 194 additions and 32 deletions

View File

@ -644,10 +644,12 @@ async def _initialize_tools(
collection_name = f"project_{project_id}" if project_id else "default_project" collection_name = f"project_{project_id}" if project_id else "default_project"
# 创建 CodeRetriever用于搜索 # 创建 CodeRetriever用于搜索
# 🔥 传递 api_key用于自动适配 collection 的 embedding 配置
retriever = CodeRetriever( retriever = CodeRetriever(
collection_name=collection_name, collection_name=collection_name,
embedding_service=embedding_service, embedding_service=embedding_service,
persist_directory=settings.VECTOR_DB_PATH, persist_directory=settings.VECTOR_DB_PATH,
api_key=embedding_api_key, # 🔥 传递 api_key 以支持自动切换 embedding
) )
logger.info(f"✅ RAG 系统初始化成功: collection={collection_name}") logger.info(f"✅ RAG 系统初始化成功: collection={collection_name}")

View File

@ -498,19 +498,21 @@ class EmbeddingService:
self.cache_enabled = cache_enabled self.cache_enabled = cache_enabled
self._cache: Dict[str, List[float]] = {} self._cache: Dict[str, List[float]] = {}
# 确定提供商 # 确定提供商(保存原始值用于属性访问)
provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai') self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
self.api_key = api_key
self.base_url = base_url
# 创建提供商实例 # 创建提供商实例
self._provider = self._create_provider( self._provider = self._create_provider(
provider=provider, provider=self.provider,
model=model, model=self.model,
api_key=api_key, api_key=api_key,
base_url=base_url, base_url=base_url,
) )
logger.info(f"Embedding service initialized with {provider}/{model}") logger.info(f"Embedding service initialized with {self.provider}/{self.model}")
def _create_provider( def _create_provider(
self, self,

View File

@ -114,9 +114,11 @@ class ChromaVectorStore(VectorStore):
self, self,
collection_name: str, collection_name: str,
persist_directory: Optional[str] = None, persist_directory: Optional[str] = None,
embedding_config: Optional[Dict[str, Any]] = None, # 🔥 新增embedding 配置
): ):
self.collection_name = collection_name self.collection_name = collection_name
self.persist_directory = persist_directory self.persist_directory = persist_directory
self.embedding_config = embedding_config or {} # 🔥 存储 embedding 配置
self._client = None self._client = None
self._collection = None self._collection = None
@ -136,9 +138,19 @@ class ChromaVectorStore(VectorStore):
settings=Settings(anonymized_telemetry=False), settings=Settings(anonymized_telemetry=False),
) )
# 🔥 构建 collection 元数据,包含 embedding 配置
collection_metadata = {"hnsw:space": "cosine"}
if self.embedding_config:
# 在元数据中记录 embedding 配置
collection_metadata["embedding_provider"] = self.embedding_config.get("provider", "openai")
collection_metadata["embedding_model"] = self.embedding_config.get("model", "text-embedding-3-small")
collection_metadata["embedding_dimension"] = self.embedding_config.get("dimension", 1536)
if self.embedding_config.get("base_url"):
collection_metadata["embedding_base_url"] = self.embedding_config.get("base_url")
self._collection = self._client.get_or_create_collection( self._collection = self._client.get_or_create_collection(
name=self.collection_name, name=self.collection_name,
metadata={"hnsw:space": "cosine"}, metadata=collection_metadata,
) )
logger.info(f"Chroma collection '{self.collection_name}' initialized") logger.info(f"Chroma collection '{self.collection_name}' initialized")
@ -146,6 +158,24 @@ class ChromaVectorStore(VectorStore):
except ImportError: except ImportError:
raise ImportError("chromadb is required. Install with: pip install chromadb") raise ImportError("chromadb is required. Install with: pip install chromadb")
def get_embedding_config(self) -> Dict[str, Any]:
    """Return the embedding configuration stored on the Chroma collection.

    Reads back the ``embedding_*`` keys that were written into the
    collection metadata when the collection was created.

    Returns:
        Dict with ``provider``, ``model``, ``dimension`` and ``base_url``
        keys (``None`` for any key the metadata lacks), or an empty dict
        when the collection has not been initialized yet.
    """
    if not self._collection:
        return {}
    meta = self._collection.metadata or {}
    return {
        key: meta.get(f"embedding_{key}")
        for key in ("provider", "model", "dimension", "base_url")
    }
async def add_documents( async def add_documents(
self, self,
ids: List[str], ids: List[str],
@ -335,6 +365,14 @@ class CodeIndexer:
self.embedding_service = embedding_service or EmbeddingService() self.embedding_service = embedding_service or EmbeddingService()
self.splitter = splitter or CodeSplitter() self.splitter = splitter or CodeSplitter()
# 🔥 从 embedding_service 获取配置,用于存储到 collection 元数据
embedding_config = {
"provider": getattr(self.embedding_service, 'provider', 'openai'),
"model": getattr(self.embedding_service, 'model', 'text-embedding-3-small'),
"dimension": getattr(self.embedding_service, 'dimension', 1536),
"base_url": getattr(self.embedding_service, 'base_url', None),
}
# 创建向量存储 # 创建向量存储
if vector_store: if vector_store:
self.vector_store = vector_store self.vector_store = vector_store
@ -343,6 +381,7 @@ class CodeIndexer:
self.vector_store = ChromaVectorStore( self.vector_store = ChromaVectorStore(
collection_name=collection_name, collection_name=collection_name,
persist_directory=persist_directory, persist_directory=persist_directory,
embedding_config=embedding_config, # 🔥 传递 embedding 配置
) )
except ImportError: except ImportError:
logger.warning("Chroma not available, using in-memory store") logger.warning("Chroma not available, using in-memory store")

View File

@ -4,6 +4,7 @@
""" """
import re import re
import asyncio
import logging import logging
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -75,6 +76,10 @@ class CodeRetriever:
""" """
代码检索器 代码检索器
支持语义检索关键字检索和混合检索 支持语义检索关键字检索和混合检索
🔥 自动兼容不同维度的向量
- 查询时自动检测 collection embedding 配置
- 动态创建对应的 embedding 服务
""" """
def __init__( def __init__(
@ -83,18 +88,22 @@ class CodeRetriever:
embedding_service: Optional[EmbeddingService] = None, embedding_service: Optional[EmbeddingService] = None,
vector_store: Optional[VectorStore] = None, vector_store: Optional[VectorStore] = None,
persist_directory: Optional[str] = None, persist_directory: Optional[str] = None,
api_key: Optional[str] = None, # 🔥 新增:用于动态创建 embedding 服务
): ):
""" """
初始化检索器 初始化检索器
Args: Args:
collection_name: 向量集合名称 collection_name: 向量集合名称
embedding_service: 嵌入服务 embedding_service: 嵌入服务可选会根据 collection 配置自动创建
vector_store: 向量存储 vector_store: 向量存储
persist_directory: 持久化目录 persist_directory: 持久化目录
api_key: API Key用于动态创建 embedding 服务
""" """
self.collection_name = collection_name self.collection_name = collection_name
self.embedding_service = embedding_service or EmbeddingService() self._provided_embedding_service = embedding_service # 用户提供的 embedding 服务
self.embedding_service = embedding_service # 实际使用的 embedding 服务
self._api_key = api_key
# 创建向量存储 # 创建向量存储
if vector_store: if vector_store:
@ -112,10 +121,120 @@ class CodeRetriever:
self._initialized = False self._initialized = False
async def initialize(self):
    """Initialize the retriever, adapting to the collection's embedding config.

    On first call this initializes the underlying vector store, reads the
    embedding configuration recorded on the collection (falling back to
    dimension-based inference for legacy collections), and — when that
    configuration differs from the currently configured embedding service —
    replaces the service with one matching the collection. Subsequent calls
    are no-ops.
    """
    if self._initialized:
        return

    await self.vector_store.initialize()

    # Detect the embedding configuration the collection was built with.
    if hasattr(self.vector_store, 'get_embedding_config'):
        stored_config = self.vector_store.get_embedding_config()
        stored_provider = stored_config.get("provider")
        stored_model = stored_config.get("model")
        stored_dimension = stored_config.get("dimension")
        stored_base_url = stored_config.get("base_url")

        # Legacy collections carry no stored config; try to infer it from
        # the dimension of the vectors already present in the store.
        if not (stored_provider and stored_model):
            inferred = await self._infer_embedding_config_from_dimension()
            if inferred:
                stored_provider = inferred.get("provider")
                stored_model = inferred.get("model")
                stored_dimension = inferred.get("dimension")
                logger.info(f"📊 从向量维度推断 embedding 配置: {stored_provider}/{stored_model}")

        if stored_provider and stored_model:
            # Compare against whatever embedding service we currently hold.
            svc = self.embedding_service
            current_provider = getattr(svc, 'provider', None) if svc else None
            current_model = getattr(svc, 'model', None) if svc else None

            if (current_provider, current_model) != (stored_provider, stored_model):
                logger.info(
                    f"🔄 Collection 使用的 embedding 配置与当前不同: "
                    f"{stored_provider}/{stored_model} (维度: {stored_dimension}) vs "
                    f"{current_provider}/{current_model}"
                )
                logger.info(f"🔄 自动切换到 collection 的 embedding 配置")

                # Prefer the explicitly supplied key; otherwise reuse the key
                # of the caller-provided embedding service, if any.
                api_key = self._api_key
                if not api_key and self._provided_embedding_service:
                    api_key = getattr(self._provided_embedding_service, 'api_key', None)

                self.embedding_service = EmbeddingService(
                    provider=stored_provider,
                    model=stored_model,
                    api_key=api_key,
                    base_url=stored_base_url,
                )
                logger.info(f"✅ 已切换到: {stored_provider}/{stored_model}")

    # Fall back to a default embedding service when none was provided and
    # none could be derived from the collection.
    if not self.embedding_service:
        self.embedding_service = EmbeddingService()

    self._initialized = True
async def _infer_embedding_config_from_dimension(self) -> Optional[Dict[str, Any]]:
"""
🔥 从向量维度推断 embedding 配置用于处理旧的 collection
Returns:
推断的 embedding 配置如果无法推断则返回 None
"""
try:
# 获取一个样本向量来检查维度
if hasattr(self.vector_store, '_collection') and self.vector_store._collection:
count = await self.vector_store.get_count()
if count > 0:
sample = await asyncio.to_thread(
self.vector_store._collection.peek,
limit=1
)
embeddings = sample.get("embeddings")
if embeddings is not None and len(embeddings) > 0:
dim = len(embeddings[0])
# 🔥 根据维度推断模型(优先选择常用模型)
dimension_mapping = {
# OpenAI 系列
1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536},
3072: {"provider": "openai", "model": "text-embedding-3-large", "dimension": 3072},
# HuggingFace 系列
1024: {"provider": "huggingface", "model": "BAAI/bge-m3", "dimension": 1024},
384: {"provider": "huggingface", "model": "sentence-transformers/all-MiniLM-L6-v2", "dimension": 384},
# Ollama 系列
768: {"provider": "ollama", "model": "nomic-embed-text", "dimension": 768},
# Jina 系列
512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512},
# Cohere 系列
# 1024 已被 HuggingFace 占用Cohere 维度相同时会默认使用 HuggingFace
}
inferred = dimension_mapping.get(dim)
if inferred:
logger.info(f"📊 检测到向量维度 {dim},推断为: {inferred['provider']}/{inferred['model']}")
return inferred
except Exception as e:
logger.warning(f"无法推断 embedding 配置: {e}")
return None
def get_collection_embedding_config(self) -> Dict[str, Any]:
    """Return the embedding configuration recorded on the backing collection.

    Delegates to the vector store when it exposes ``get_embedding_config``;
    stores without that capability yield an empty dict.

    Returns:
        Dict with ``provider``, ``model``, ``dimension`` and ``base_url``
        keys, or an empty dict.
    """
    store = self.vector_store
    return store.get_embedding_config() if hasattr(store, 'get_embedding_config') else {}
async def retrieve( async def retrieve(
self, self,