feat(retriever): 添加自动适配不同 embedding 配置的功能

支持自动检测 collection 的 embedding 配置并动态创建对应的 embedding 服务
新增从向量维度推断配置的功能,兼容旧的 collection
This commit is contained in:
lintsinghua 2025-12-16 15:28:03 +08:00
parent 3176c35817
commit 17889dceee
4 changed files with 194 additions and 32 deletions

View File

@ -644,10 +644,12 @@ async def _initialize_tools(
collection_name = f"project_{project_id}" if project_id else "default_project"
# 创建 CodeRetriever,用于搜索
# 🔥 传递 api_key用于自动适配 collection 的 embedding 配置
retriever = CodeRetriever(
collection_name=collection_name,
embedding_service=embedding_service,
persist_directory=settings.VECTOR_DB_PATH,
api_key=embedding_api_key, # 🔥 传递 api_key 以支持自动切换 embedding
)
logger.info(f"✅ RAG 系统初始化成功: collection={collection_name}")

View File

@ -498,19 +498,21 @@ class EmbeddingService:
self.cache_enabled = cache_enabled
self._cache: Dict[str, List[float]] = {}
# 确定提供商
provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
# 确定提供商(保存原始值用于属性访问)
self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
self.api_key = api_key
self.base_url = base_url
# 创建提供商实例
self._provider = self._create_provider(
provider=provider,
model=model,
provider=self.provider,
model=self.model,
api_key=api_key,
base_url=base_url,
)
logger.info(f"Embedding service initialized with {provider}/{model}")
logger.info(f"Embedding service initialized with {self.provider}/{self.model}")
def _create_provider(
self,

View File

@ -114,9 +114,11 @@ class ChromaVectorStore(VectorStore):
self,
collection_name: str,
persist_directory: Optional[str] = None,
embedding_config: Optional[Dict[str, Any]] = None, # 🔥 新增embedding 配置
):
self.collection_name = collection_name
self.persist_directory = persist_directory
self.embedding_config = embedding_config or {} # 🔥 存储 embedding 配置
self._client = None
self._collection = None
@ -136,9 +138,19 @@ class ChromaVectorStore(VectorStore):
settings=Settings(anonymized_telemetry=False),
)
# 🔥 构建 collection 元数据,包含 embedding 配置
collection_metadata = {"hnsw:space": "cosine"}
if self.embedding_config:
# 在元数据中记录 embedding 配置
collection_metadata["embedding_provider"] = self.embedding_config.get("provider", "openai")
collection_metadata["embedding_model"] = self.embedding_config.get("model", "text-embedding-3-small")
collection_metadata["embedding_dimension"] = self.embedding_config.get("dimension", 1536)
if self.embedding_config.get("base_url"):
collection_metadata["embedding_base_url"] = self.embedding_config.get("base_url")
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={"hnsw:space": "cosine"},
metadata=collection_metadata,
)
logger.info(f"Chroma collection '{self.collection_name}' initialized")
@ -146,6 +158,24 @@ class ChromaVectorStore(VectorStore):
except ImportError:
raise ImportError("chromadb is required. Install with: pip install chromadb")
def get_embedding_config(self) -> Dict[str, Any]:
    """Return the embedding configuration stored in the collection's metadata.

    Reads the ``embedding_*`` keys that were written into the collection
    metadata when the collection was created.

    Returns:
        Dict with ``provider``, ``model``, ``dimension`` and ``base_url``
        (each ``None`` if the key is absent from the metadata), or an
        empty dict when the collection has not been initialized yet.
    """
    if not self._collection:
        return {}
    metadata = self._collection.metadata or {}
    return {
        "provider": metadata.get("embedding_provider"),
        "model": metadata.get("embedding_model"),
        "dimension": metadata.get("embedding_dimension"),
        "base_url": metadata.get("embedding_base_url"),
    }
async def add_documents(
self,
ids: List[str],
@ -335,6 +365,14 @@ class CodeIndexer:
self.embedding_service = embedding_service or EmbeddingService()
self.splitter = splitter or CodeSplitter()
# 🔥 从 embedding_service 获取配置,用于存储到 collection 元数据
embedding_config = {
"provider": getattr(self.embedding_service, 'provider', 'openai'),
"model": getattr(self.embedding_service, 'model', 'text-embedding-3-small'),
"dimension": getattr(self.embedding_service, 'dimension', 1536),
"base_url": getattr(self.embedding_service, 'base_url', None),
}
# 创建向量存储
if vector_store:
self.vector_store = vector_store
@ -343,6 +381,7 @@ class CodeIndexer:
self.vector_store = ChromaVectorStore(
collection_name=collection_name,
persist_directory=persist_directory,
embedding_config=embedding_config, # 🔥 传递 embedding 配置
)
except ImportError:
logger.warning("Chroma not available, using in-memory store")

View File

@ -4,6 +4,7 @@
"""
import re
import asyncio
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
@ -75,6 +76,10 @@ class CodeRetriever:
"""
代码检索器
支持语义检索、关键字检索和混合检索
🔥 自动兼容不同维度的向量
- 查询时自动检测 collection embedding 配置
- 动态创建对应的 embedding 服务
"""
def __init__(
@ -83,18 +88,22 @@ class CodeRetriever:
embedding_service: Optional[EmbeddingService] = None,
vector_store: Optional[VectorStore] = None,
persist_directory: Optional[str] = None,
api_key: Optional[str] = None, # 🔥 新增:用于动态创建 embedding 服务
):
"""
初始化检索器
Args:
collection_name: 向量集合名称
embedding_service: 嵌入服务
embedding_service: 嵌入服务可选会根据 collection 配置自动创建
vector_store: 向量存储
persist_directory: 持久化目录
api_key: API Key用于动态创建 embedding 服务
"""
self.collection_name = collection_name
self.embedding_service = embedding_service or EmbeddingService()
self._provided_embedding_service = embedding_service # 用户提供的 embedding 服务
self.embedding_service = embedding_service # 实际使用的 embedding 服务
self._api_key = api_key
# 创建向量存储
if vector_store:
@ -112,11 +121,121 @@ class CodeRetriever:
self._initialized = False
async def initialize(self):
"""初始化检索器"""
if not self._initialized:
"""初始化检索器,自动检测并适配 collection 的 embedding 配置"""
if self._initialized:
return
await self.vector_store.initialize()
# 🔥 自动检测 collection 的 embedding 配置
if hasattr(self.vector_store, 'get_embedding_config'):
stored_config = self.vector_store.get_embedding_config()
stored_provider = stored_config.get("provider")
stored_model = stored_config.get("model")
stored_dimension = stored_config.get("dimension")
stored_base_url = stored_config.get("base_url")
# 🔥 如果没有存储的配置(旧的 collection尝试通过维度推断
if not stored_provider or not stored_model:
inferred = await self._infer_embedding_config_from_dimension()
if inferred:
stored_provider = inferred.get("provider")
stored_model = inferred.get("model")
stored_dimension = inferred.get("dimension")
logger.info(f"📊 从向量维度推断 embedding 配置: {stored_provider}/{stored_model}")
if stored_provider and stored_model:
# 检查是否需要使用不同的 embedding 服务
current_provider = getattr(self.embedding_service, 'provider', None) if self.embedding_service else None
current_model = getattr(self.embedding_service, 'model', None) if self.embedding_service else None
if current_provider != stored_provider or current_model != stored_model:
logger.info(
f"🔄 Collection 使用的 embedding 配置与当前不同: "
f"{stored_provider}/{stored_model} (维度: {stored_dimension}) vs "
f"{current_provider}/{current_model}"
)
logger.info(f"🔄 自动切换到 collection 的 embedding 配置")
# 动态创建对应的 embedding 服务
api_key = self._api_key
if not api_key and self._provided_embedding_service:
api_key = getattr(self._provided_embedding_service, 'api_key', None)
self.embedding_service = EmbeddingService(
provider=stored_provider,
model=stored_model,
api_key=api_key,
base_url=stored_base_url,
)
logger.info(f"✅ 已切换到: {stored_provider}/{stored_model}")
# 如果仍然没有 embedding 服务,创建默认的
if not self.embedding_service:
self.embedding_service = EmbeddingService()
self._initialized = True
async def _infer_embedding_config_from_dimension(self) -> Optional[Dict[str, Any]]:
    """Infer the embedding configuration from the stored vector dimension.

    Fallback for legacy collections created before embedding metadata was
    recorded: peeks at one stored vector and maps its dimensionality to a
    known provider/model combination.

    Returns:
        Inferred config dict with ``provider``, ``model`` and ``dimension``,
        or ``None`` when the collection is empty, the dimension is unknown,
        or inspection fails.
    """
    try:
        # Sample a single stored vector to determine its dimensionality.
        # NOTE(review): reaches into the Chroma-specific private attribute
        # `_collection`; non-Chroma VectorStore backends are skipped here.
        if hasattr(self.vector_store, '_collection') and self.vector_store._collection:
            count = await self.vector_store.get_count()
            if count > 0:
                # `peek` is a blocking call; run it off the event loop.
                sample = await asyncio.to_thread(
                    self.vector_store._collection.peek,
                    limit=1
                )
                embeddings = sample.get("embeddings")
                # Explicit `is not None` check: embeddings may be an
                # array-like whose bare truthiness is ambiguous.
                if embeddings is not None and len(embeddings) > 0:
                    dim = len(embeddings[0])
                    # Map dimension -> most likely provider/model
                    # (common models win when dimensions collide).
                    dimension_mapping = {
                        # OpenAI family
                        1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536},
                        3072: {"provider": "openai", "model": "text-embedding-3-large", "dimension": 3072},
                        # HuggingFace family
                        1024: {"provider": "huggingface", "model": "BAAI/bge-m3", "dimension": 1024},
                        384: {"provider": "huggingface", "model": "sentence-transformers/all-MiniLM-L6-v2", "dimension": 384},
                        # Ollama family
                        768: {"provider": "ollama", "model": "nomic-embed-text", "dimension": 768},
                        # Jina family
                        512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512},
                        # Cohere family: 1024 is already claimed by HuggingFace
                        # above, so same-dimension Cohere collections resolve
                        # to HuggingFace by default.
                    }
                    inferred = dimension_mapping.get(dim)
                    if inferred:
                        logger.info(f"📊 检测到向量维度 {dim},推断为: {inferred['provider']}/{inferred['model']}")
                        return inferred
    except Exception as e:
        # Best-effort inference: log and fall through to None rather
        # than failing retriever initialization.
        logger.warning(f"无法推断 embedding 配置: {e}")
    return None
def get_collection_embedding_config(self) -> Dict[str, Any]:
"""
获取 collection 存储的 embedding 配置
Returns:
包含 provider, model, dimension, base_url 的字典
"""
if hasattr(self.vector_store, 'get_embedding_config'):
return self.vector_store.get_embedding_config()
return {}
async def retrieve(
self,
query: str,