feat(retriever): add automatic adaptation to different embedding configurations
Automatically detect a collection's embedding configuration and dynamically create the matching embedding service. Add inference of the configuration from the vector dimension, for compatibility with legacy collections.
parent 3176c35817, commit 17889dceee
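A minimal usage sketch of the new behaviour (the import path, function name, and argument values are hypothetical placeholders; only the CodeRetriever parameters visible in the hunks below are assumed):

import asyncio
from rag.retriever import CodeRetriever  # hypothetical module path

async def search_project(query: str, api_key: str):
    # No embedding_service is passed: during initialize() the retriever reads the
    # collection's metadata (or infers the config from the vector dimension for
    # legacy collections) and creates a matching EmbeddingService on its own.
    retriever = CodeRetriever(
        collection_name="project_123",          # placeholder collection name
        persist_directory="./vector_db",        # placeholder path
        api_key=api_key,                        # used when the service is created dynamically
    )
    await retriever.initialize()
    print(retriever.get_collection_embedding_config())
    # retrieve() is truncated in the diff; a single query string is assumed here.
    return await retriever.retrieve(query)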
@@ -644,10 +644,12 @@ async def _initialize_tools(
     collection_name = f"project_{project_id}" if project_id else "default_project"
 
     # Create the CodeRetriever (used for search)
+    # 🔥 Pass api_key, used to auto-adapt to the collection's embedding configuration
     retriever = CodeRetriever(
         collection_name=collection_name,
         embedding_service=embedding_service,
         persist_directory=settings.VECTOR_DB_PATH,
+        api_key=embedding_api_key,  # 🔥 pass api_key to support automatic embedding switching
     )
 
     logger.info(f"✅ RAG system initialized successfully: collection={collection_name}")
@@ -487,7 +487,7 @@ class EmbeddingService:
     ):
         """
         Initialize the embedding service
 
         Args:
             provider: provider (openai, azure, ollama, cohere, huggingface, jina)
             model: model name
@@ -497,20 +497,22 @@ class EmbeddingService:
         """
         self.cache_enabled = cache_enabled
         self._cache: Dict[str, List[float]] = {}
 
-        # Determine the provider
-        provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
-        model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
+        # Determine the provider (keep the original values for attribute access)
+        self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
+        self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
+        self.api_key = api_key
+        self.base_url = base_url
 
         # Create the provider instance
         self._provider = self._create_provider(
-            provider=provider,
-            model=model,
+            provider=self.provider,
+            model=self.model,
             api_key=api_key,
             base_url=base_url,
         )
 
-        logger.info(f"Embedding service initialized with {provider}/{model}")
+        logger.info(f"Embedding service initialized with {self.provider}/{self.model}")
 
     def _create_provider(
         self,
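The point of promoting provider, model, api_key, and base_url to instance attributes is that downstream components introspect them. A minimal sketch, assuming only the constructor parameters shown above (the import path is hypothetical):

from rag.embeddings import EmbeddingService  # hypothetical module path

service = EmbeddingService(provider="ollama", model="nomic-embed-text")

# CodeIndexer builds its embedding_config from these attributes (via getattr),
# and CodeRetriever.initialize() compares them against the collection's stored
# provider/model to decide whether to switch embedding services.
assert service.provider == "ollama"
assert service.model == "nomic-embed-text"
print(service.api_key, service.base_url)  # both None unless passed explicitly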
@@ -109,23 +109,25 @@ class VectorStore:
 
 class ChromaVectorStore(VectorStore):
     """Chroma vector store"""
 
     def __init__(
         self,
         collection_name: str,
         persist_directory: Optional[str] = None,
+        embedding_config: Optional[Dict[str, Any]] = None,  # 🔥 new: embedding configuration
     ):
         self.collection_name = collection_name
         self.persist_directory = persist_directory
+        self.embedding_config = embedding_config or {}  # 🔥 store the embedding configuration
         self._client = None
         self._collection = None
 
     async def initialize(self):
         """Initialize Chroma"""
         try:
             import chromadb
             from chromadb.config import Settings
 
             if self.persist_directory:
                 self._client = chromadb.PersistentClient(
                     path=self.persist_directory,
@@ -135,16 +137,44 @@ class ChromaVectorStore(VectorStore):
                 self._client = chromadb.Client(
                     settings=Settings(anonymized_telemetry=False),
                 )
 
+            # 🔥 Build the collection metadata, including the embedding configuration
+            collection_metadata = {"hnsw:space": "cosine"}
+            if self.embedding_config:
+                # Record the embedding configuration in the metadata
+                collection_metadata["embedding_provider"] = self.embedding_config.get("provider", "openai")
+                collection_metadata["embedding_model"] = self.embedding_config.get("model", "text-embedding-3-small")
+                collection_metadata["embedding_dimension"] = self.embedding_config.get("dimension", 1536)
+                if self.embedding_config.get("base_url"):
+                    collection_metadata["embedding_base_url"] = self.embedding_config.get("base_url")
+
             self._collection = self._client.get_or_create_collection(
                 name=self.collection_name,
-                metadata={"hnsw:space": "cosine"},
+                metadata=collection_metadata,
             )
 
             logger.info(f"Chroma collection '{self.collection_name}' initialized")
 
         except ImportError:
             raise ImportError("chromadb is required. Install with: pip install chromadb")
 
+    def get_embedding_config(self) -> Dict[str, Any]:
+        """
+        🔥 Get the collection's embedding configuration
+
+        Returns:
+            A dict containing provider, model, dimension, base_url
+        """
+        if not self._collection:
+            return {}
+
+        metadata = self._collection.metadata or {}
+        return {
+            "provider": metadata.get("embedding_provider"),
+            "model": metadata.get("embedding_model"),
+            "dimension": metadata.get("embedding_dimension"),
+            "base_url": metadata.get("embedding_base_url"),
+        }
+
     async def add_documents(
         self,
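For reference, the metadata round-trip relies only on plain chromadb collection metadata; a sketch outside the project's wrapper classes (collection name and path are placeholders):

import chromadb

client = chromadb.PersistentClient(path="./vector_db")  # placeholder path

# Writing: ChromaVectorStore.initialize() attaches the embedding config
# alongside the HNSW setting when the collection is created.
collection = client.get_or_create_collection(
    name="project_123",
    metadata={
        "hnsw:space": "cosine",
        "embedding_provider": "openai",
        "embedding_model": "text-embedding-3-small",
        "embedding_dimension": 1536,
    },
)

# Reading: get_embedding_config() simply maps these keys back into a dict.
meta = collection.metadata or {}
config = {
    "provider": meta.get("embedding_provider"),
    "model": meta.get("embedding_model"),
    "dimension": meta.get("embedding_dimension"),
    "base_url": meta.get("embedding_base_url"),
}
print(config)

Note that get_or_create_collection generally applies the metadata only when the collection is first created, so collections indexed before this change keep whatever metadata they already had; that is exactly the legacy case the retriever's dimension-inference fallback below is meant to cover.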
@@ -312,7 +342,7 @@ class CodeIndexer:
     Code indexer
     Splits code files into chunks, embeds them, and indexes them into the vector database
     """
 
     def __init__(
         self,
         collection_name: str,
@@ -323,7 +353,7 @@ class CodeIndexer:
     ):
         """
         Initialize the indexer
 
         Args:
             collection_name: name of the vector collection
             embedding_service: embedding service
@@ -334,7 +364,15 @@ class CodeIndexer:
         self.collection_name = collection_name
         self.embedding_service = embedding_service or EmbeddingService()
         self.splitter = splitter or CodeSplitter()
 
+        # 🔥 Read the configuration from embedding_service, to be stored in the collection metadata
+        embedding_config = {
+            "provider": getattr(self.embedding_service, 'provider', 'openai'),
+            "model": getattr(self.embedding_service, 'model', 'text-embedding-3-small'),
+            "dimension": getattr(self.embedding_service, 'dimension', 1536),
+            "base_url": getattr(self.embedding_service, 'base_url', None),
+        }
+
         # Create the vector store
         if vector_store:
             self.vector_store = vector_store
@@ -343,11 +381,12 @@ class CodeIndexer:
                 self.vector_store = ChromaVectorStore(
                     collection_name=collection_name,
                     persist_directory=persist_directory,
+                    embedding_config=embedding_config,  # 🔥 pass the embedding configuration
                 )
             except ImportError:
                 logger.warning("Chroma not available, using in-memory store")
                 self.vector_store = InMemoryVectorStore(collection_name=collection_name)
 
         self._initialized = False
 
     async def initialize(self):
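A sketch of how the indexing side now records the configuration (import paths are hypothetical; only constructor parameters visible in the hunks above are assumed):

from rag.embeddings import EmbeddingService  # hypothetical module paths
from rag.indexer import CodeIndexer

embedding_service = EmbeddingService(provider="ollama", model="nomic-embed-text")

# The indexer reads provider/model/dimension/base_url from the service and
# hands them to ChromaVectorStore, which writes them into the collection
# metadata so a later CodeRetriever can recreate the same configuration.
indexer = CodeIndexer(
    collection_name="project_123",          # placeholder
    embedding_service=embedding_service,
    persist_directory="./vector_db",        # placeholder
)

Because the config is read with getattr and defaults, any attribute the service does not expose falls back to openai / text-embedding-3-small / 1536; the __init__ shown above does not set a dimension attribute, so the default dimension applies unless the provider supplies one.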
@@ -4,6 +4,7 @@
 """
 
 import re
+import asyncio
 import logging
 from typing import List, Dict, Any, Optional
 from dataclasses import dataclass, field
@@ -75,27 +76,35 @@ class CodeRetriever:
     """
     Code retriever
     Supports semantic, keyword, and hybrid retrieval
+
+    🔥 Automatically compatible with vectors of different dimensions:
+    - detects the collection's embedding configuration at query time
+    - dynamically creates the matching embedding service
     """
 
     def __init__(
         self,
         collection_name: str,
         embedding_service: Optional[EmbeddingService] = None,
         vector_store: Optional[VectorStore] = None,
         persist_directory: Optional[str] = None,
+        api_key: Optional[str] = None,  # 🔥 new: used to create the embedding service dynamically
     ):
         """
         Initialize the retriever
 
         Args:
             collection_name: name of the vector collection
-            embedding_service: embedding service
+            embedding_service: embedding service (optional; created automatically from the collection configuration)
             vector_store: vector store
             persist_directory: persistence directory
+            api_key: API key (used to create the embedding service dynamically)
         """
         self.collection_name = collection_name
-        self.embedding_service = embedding_service or EmbeddingService()
+        self._provided_embedding_service = embedding_service  # embedding service supplied by the caller
+        self.embedding_service = embedding_service  # embedding service actually in use
+        self._api_key = api_key
 
         # Create the vector store
         if vector_store:
             self.vector_store = vector_store
@@ -108,14 +117,124 @@ class CodeRetriever:
             except ImportError:
                 logger.warning("Chroma not available, using in-memory store")
                 self.vector_store = InMemoryVectorStore(collection_name=collection_name)
 
         self._initialized = False
 
     async def initialize(self):
-        """Initialize the retriever"""
-        if not self._initialized:
-            await self.vector_store.initialize()
-            self._initialized = True
+        """Initialize the retriever, auto-detecting and adapting to the collection's embedding configuration"""
+        if self._initialized:
+            return
+
+        await self.vector_store.initialize()
+
+        # 🔥 Auto-detect the collection's embedding configuration
+        if hasattr(self.vector_store, 'get_embedding_config'):
+            stored_config = self.vector_store.get_embedding_config()
+            stored_provider = stored_config.get("provider")
+            stored_model = stored_config.get("model")
+            stored_dimension = stored_config.get("dimension")
+            stored_base_url = stored_config.get("base_url")
+
+            # 🔥 If no configuration is stored (legacy collection), try to infer it from the dimension
+            if not stored_provider or not stored_model:
+                inferred = await self._infer_embedding_config_from_dimension()
+                if inferred:
+                    stored_provider = inferred.get("provider")
+                    stored_model = inferred.get("model")
+                    stored_dimension = inferred.get("dimension")
+                    logger.info(f"📊 Inferred embedding configuration from vector dimension: {stored_provider}/{stored_model}")
+
+            if stored_provider and stored_model:
+                # Check whether a different embedding service is needed
+                current_provider = getattr(self.embedding_service, 'provider', None) if self.embedding_service else None
+                current_model = getattr(self.embedding_service, 'model', None) if self.embedding_service else None
+
+                if current_provider != stored_provider or current_model != stored_model:
+                    logger.info(
+                        f"🔄 The collection's embedding configuration differs from the current one: "
+                        f"{stored_provider}/{stored_model} (dimension: {stored_dimension}) vs "
+                        f"{current_provider}/{current_model}"
+                    )
+                    logger.info(f"🔄 Automatically switching to the collection's embedding configuration")
+
+                    # Dynamically create the matching embedding service
+                    api_key = self._api_key
+                    if not api_key and self._provided_embedding_service:
+                        api_key = getattr(self._provided_embedding_service, 'api_key', None)
+
+                    self.embedding_service = EmbeddingService(
+                        provider=stored_provider,
+                        model=stored_model,
+                        api_key=api_key,
+                        base_url=stored_base_url,
+                    )
+                    logger.info(f"✅ Switched to: {stored_provider}/{stored_model}")
+
+        # If there is still no embedding service, create a default one
+        if not self.embedding_service:
+            self.embedding_service = EmbeddingService()
+
+        self._initialized = True
+
+    async def _infer_embedding_config_from_dimension(self) -> Optional[Dict[str, Any]]:
+        """
+        🔥 Infer the embedding configuration from the vector dimension (for handling legacy collections)
+
+        Returns:
+            The inferred embedding configuration, or None if it cannot be inferred
+        """
+        try:
+            # Fetch a sample vector to check its dimension
+            if hasattr(self.vector_store, '_collection') and self.vector_store._collection:
+                count = await self.vector_store.get_count()
+                if count > 0:
+                    sample = await asyncio.to_thread(
+                        self.vector_store._collection.peek,
+                        limit=1
+                    )
+                    embeddings = sample.get("embeddings")
+                    if embeddings is not None and len(embeddings) > 0:
+                        dim = len(embeddings[0])
+
+                        # 🔥 Infer the model from the dimension (prefer commonly used models)
+                        dimension_mapping = {
+                            # OpenAI family
+                            1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536},
+                            3072: {"provider": "openai", "model": "text-embedding-3-large", "dimension": 3072},
+
+                            # HuggingFace family
+                            1024: {"provider": "huggingface", "model": "BAAI/bge-m3", "dimension": 1024},
+                            384: {"provider": "huggingface", "model": "sentence-transformers/all-MiniLM-L6-v2", "dimension": 384},
+
+                            # Ollama family
+                            768: {"provider": "ollama", "model": "nomic-embed-text", "dimension": 768},
+
+                            # Jina family
+                            512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512},
+
+                            # Cohere family
+                            # 1024 is already taken by HuggingFace; when Cohere shares a dimension, HuggingFace is used by default
+                        }
+
+                        inferred = dimension_mapping.get(dim)
+                        if inferred:
+                            logger.info(f"📊 Detected vector dimension {dim}, inferred as: {inferred['provider']}/{inferred['model']}")
+                            return inferred
+        except Exception as e:
+            logger.warning(f"Unable to infer embedding configuration: {e}")
+
+        return None
+
+    def get_collection_embedding_config(self) -> Dict[str, Any]:
+        """
+        Get the embedding configuration stored in the collection
+
+        Returns:
+            A dict containing provider, model, dimension, base_url
+        """
+        if hasattr(self.vector_store, 'get_embedding_config'):
+            return self.vector_store.get_embedding_config()
+        return {}
+
     async def retrieve(
         self,
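The dimension fallback is easiest to see end-to-end with a legacy collection that has no embedding_* metadata (a hedged sketch; the module path, collection name, and directory are placeholders):

import asyncio
from rag.retriever import CodeRetriever  # hypothetical module path

async def main():
    # A collection indexed before this change contains vectors but no embedding metadata.
    retriever = CodeRetriever(
        collection_name="legacy_project",
        persist_directory="./vector_db",
    )
    await retriever.initialize()

    # If the stored vectors are 768-dimensional, initialize() infers
    # ollama/nomic-embed-text from dimension_mapping and switches to it;
    # unmapped dimensions simply leave the default EmbeddingService in place.
    svc = retriever.embedding_service
    print(svc.provider, svc.model)

asyncio.run(main())

The mapping is inherently ambiguous (1024 could just as well be a Cohere model, as the comment in the hunk notes), so the metadata written at index time remains the authoritative source and the inference is only a best-effort fallback.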