feat(retriever): add automatic adaptation to different embedding configurations
Automatically detect a collection's embedding configuration and dynamically create the matching embedding service. Also add inference of the configuration from the vector dimension, for compatibility with older collections.
parent 3176c35817
commit 17889dceee
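For context, a minimal usage sketch of the intended flow, based on the diff below; the import path, directory, and key values are placeholders, not taken from the repository:

# Sketch only — assumes the CodeRetriever API introduced in this commit; import path is hypothetical.
from app.rag.retriever import CodeRetriever

async def search_project(project_id: int, embedding_api_key: str):
    retriever = CodeRetriever(
        collection_name=f"project_{project_id}",
        embedding_service=None,              # optional: rebuilt from the collection's stored config
        persist_directory="./vector_db",     # placeholder path
        api_key=embedding_api_key,           # lets initialize() recreate the matching EmbeddingService
    )
    await retriever.initialize()             # reads collection metadata, or infers config from vector dimension
    return await retriever.retrieve("where is the auth middleware?")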
@@ -644,10 +644,12 @@ async def _initialize_tools(
        collection_name = f"project_{project_id}" if project_id else "default_project"

        # Create CodeRetriever (used for search)
+       # 🔥 Pass api_key so the retriever can auto-adapt to the collection's embedding config
        retriever = CodeRetriever(
            collection_name=collection_name,
            embedding_service=embedding_service,
            persist_directory=settings.VECTOR_DB_PATH,
+           api_key=embedding_api_key,  # 🔥 Pass api_key to allow automatic embedding switching
        )

        logger.info(f"✅ RAG system initialized successfully: collection={collection_name}")
@@ -498,19 +498,21 @@ class EmbeddingService:
         self.cache_enabled = cache_enabled
         self._cache: Dict[str, List[float]] = {}

-        # Determine the provider
-        provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
-        model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
+        # Determine the provider (keep the resolved values for attribute access)
+        self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
+        self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
+        self.api_key = api_key
+        self.base_url = base_url

         # Create the provider instance
         self._provider = self._create_provider(
-            provider=provider,
-            model=model,
+            provider=self.provider,
+            model=self.model,
             api_key=api_key,
             base_url=base_url,
         )

-        logger.info(f"Embedding service initialized with {provider}/{model}")
+        logger.info(f"Embedding service initialized with {self.provider}/{self.model}")

     def _create_provider(
         self,
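The resolved provider/model (plus api_key and base_url) are now exposed as attributes, which CodeIndexer and CodeRetriever below read via getattr. A minimal sketch, assuming the constructor shown above; the key value is a placeholder:

svc = EmbeddingService(provider="openai", model="text-embedding-3-small", api_key="sk-placeholder")
print(svc.provider, svc.model)   # -> openai text-embedding-3-small
# Downstream code builds the collection metadata from these attributes, e.g.:
config = {"provider": getattr(svc, "provider", "openai"), "model": getattr(svc, "model", None)}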
@@ -114,9 +114,11 @@ class ChromaVectorStore(VectorStore):
         self,
         collection_name: str,
         persist_directory: Optional[str] = None,
+        embedding_config: Optional[Dict[str, Any]] = None,  # 🔥 New: embedding config
     ):
         self.collection_name = collection_name
         self.persist_directory = persist_directory
+        self.embedding_config = embedding_config or {}  # 🔥 Store the embedding config
         self._client = None
         self._collection = None
 
@@ -136,9 +138,19 @@ class ChromaVectorStore(VectorStore):
                 settings=Settings(anonymized_telemetry=False),
             )

+            # 🔥 Build the collection metadata, including the embedding config
+            collection_metadata = {"hnsw:space": "cosine"}
+            if self.embedding_config:
+                # Record the embedding config in the metadata
+                collection_metadata["embedding_provider"] = self.embedding_config.get("provider", "openai")
+                collection_metadata["embedding_model"] = self.embedding_config.get("model", "text-embedding-3-small")
+                collection_metadata["embedding_dimension"] = self.embedding_config.get("dimension", 1536)
+                if self.embedding_config.get("base_url"):
+                    collection_metadata["embedding_base_url"] = self.embedding_config.get("base_url")
+
             self._collection = self._client.get_or_create_collection(
                 name=self.collection_name,
-                metadata={"hnsw:space": "cosine"},
+                metadata=collection_metadata,
             )

             logger.info(f"Chroma collection '{self.collection_name}' initialized")
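This relies on Chroma persisting collection-level metadata alongside the vectors. A standalone sketch of that round trip (path, collection name, and values are illustrative):

import chromadb

client = chromadb.PersistentClient(path="./vector_db")   # illustrative path
col = client.get_or_create_collection(
    name="project_demo",
    metadata={
        "hnsw:space": "cosine",
        "embedding_provider": "openai",
        "embedding_model": "text-embedding-3-small",
        "embedding_dimension": 1536,
    },
)
print(col.metadata["embedding_model"])   # -> text-embedding-3-small (read back by get_embedding_config below)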
@@ -146,6 +158,24 @@ class ChromaVectorStore(VectorStore):
         except ImportError:
             raise ImportError("chromadb is required. Install with: pip install chromadb")

+    def get_embedding_config(self) -> Dict[str, Any]:
+        """
+        🔥 Get the embedding config stored on the collection
+
+        Returns:
+            A dict containing provider, model, dimension, and base_url
+        """
+        if not self._collection:
+            return {}
+
+        metadata = self._collection.metadata or {}
+        return {
+            "provider": metadata.get("embedding_provider"),
+            "model": metadata.get("embedding_model"),
+            "dimension": metadata.get("embedding_dimension"),
+            "base_url": metadata.get("embedding_base_url"),
+        }
+
     async def add_documents(
         self,
         ids: List[str],
@@ -335,6 +365,14 @@ class CodeIndexer:
         self.embedding_service = embedding_service or EmbeddingService()
         self.splitter = splitter or CodeSplitter()

+        # 🔥 Read the config from embedding_service so it can be stored in the collection metadata
+        embedding_config = {
+            "provider": getattr(self.embedding_service, 'provider', 'openai'),
+            "model": getattr(self.embedding_service, 'model', 'text-embedding-3-small'),
+            "dimension": getattr(self.embedding_service, 'dimension', 1536),
+            "base_url": getattr(self.embedding_service, 'base_url', None),
+        }
+
         # Create the vector store
         if vector_store:
             self.vector_store = vector_store
@@ -343,6 +381,7 @@ class CodeIndexer:
                 self.vector_store = ChromaVectorStore(
                     collection_name=collection_name,
                     persist_directory=persist_directory,
+                    embedding_config=embedding_config,  # 🔥 Pass the embedding config
                 )
             except ImportError:
                 logger.warning("Chroma not available, using in-memory store")
@@ -4,6 +4,7 @@
 """

 import re
+import asyncio
 import logging
 from typing import List, Dict, Any, Optional
 from dataclasses import dataclass, field
@@ -75,6 +76,10 @@ class CodeRetriever:
     """
     Code retriever
     Supports semantic, keyword, and hybrid retrieval
+
+    🔥 Automatically adapts to vectors of different dimensions:
+    - Detects the collection's embedding config at query time
+    - Dynamically creates the matching embedding service
     """

     def __init__(
@@ -83,18 +88,22 @@ class CodeRetriever:
         embedding_service: Optional[EmbeddingService] = None,
         vector_store: Optional[VectorStore] = None,
         persist_directory: Optional[str] = None,
+        api_key: Optional[str] = None,  # 🔥 New: used to dynamically create the embedding service
     ):
         """
         Initialize the retriever

         Args:
             collection_name: Name of the vector collection
-            embedding_service: Embedding service
+            embedding_service: Embedding service (optional; created automatically from the collection's config)
             vector_store: Vector store
             persist_directory: Persistence directory
+            api_key: API key (used to dynamically create the embedding service)
         """
         self.collection_name = collection_name
-        self.embedding_service = embedding_service or EmbeddingService()
+        self._provided_embedding_service = embedding_service  # The embedding service provided by the caller
+        self.embedding_service = embedding_service  # The embedding service actually in use
+        self._api_key = api_key

         # Create the vector store
         if vector_store:
@@ -112,11 +121,121 @@ class CodeRetriever:
         self._initialized = False

     async def initialize(self):
-        """Initialize the retriever"""
-        if not self._initialized:
+        """Initialize the retriever; auto-detect and adapt to the collection's embedding config"""
+        if self._initialized:
+            return
+
         await self.vector_store.initialize()
+
+        # 🔥 Auto-detect the collection's embedding config
+        if hasattr(self.vector_store, 'get_embedding_config'):
+            stored_config = self.vector_store.get_embedding_config()
+            stored_provider = stored_config.get("provider")
+            stored_model = stored_config.get("model")
+            stored_dimension = stored_config.get("dimension")
+            stored_base_url = stored_config.get("base_url")
+
+            # 🔥 If no config is stored (an older collection), try to infer it from the vector dimension
+            if not stored_provider or not stored_model:
+                inferred = await self._infer_embedding_config_from_dimension()
+                if inferred:
+                    stored_provider = inferred.get("provider")
+                    stored_model = inferred.get("model")
+                    stored_dimension = inferred.get("dimension")
+                    logger.info(f"📊 Inferred embedding config from vector dimension: {stored_provider}/{stored_model}")
+
+            if stored_provider and stored_model:
+                # Check whether a different embedding service is needed
+                current_provider = getattr(self.embedding_service, 'provider', None) if self.embedding_service else None
+                current_model = getattr(self.embedding_service, 'model', None) if self.embedding_service else None
+
+                if current_provider != stored_provider or current_model != stored_model:
+                    logger.info(
+                        f"🔄 The collection's embedding config differs from the current one: "
+                        f"{stored_provider}/{stored_model} (dimension: {stored_dimension}) vs "
+                        f"{current_provider}/{current_model}"
+                    )
+                    logger.info(f"🔄 Automatically switching to the collection's embedding config")
+
+                    # Dynamically create the matching embedding service
+                    api_key = self._api_key
+                    if not api_key and self._provided_embedding_service:
+                        api_key = getattr(self._provided_embedding_service, 'api_key', None)
+
+                    self.embedding_service = EmbeddingService(
+                        provider=stored_provider,
+                        model=stored_model,
+                        api_key=api_key,
+                        base_url=stored_base_url,
+                    )
+                    logger.info(f"✅ Switched to: {stored_provider}/{stored_model}")
+
+        # If there is still no embedding service, create a default one
+        if not self.embedding_service:
+            self.embedding_service = EmbeddingService()
+
         self._initialized = True

+    async def _infer_embedding_config_from_dimension(self) -> Optional[Dict[str, Any]]:
+        """
+        🔥 Infer the embedding config from the vector dimension (used for older collections)
+
+        Returns:
+            The inferred embedding config, or None if it cannot be inferred
+        """
+        try:
+            # Fetch a sample vector to check its dimension
+            if hasattr(self.vector_store, '_collection') and self.vector_store._collection:
+                count = await self.vector_store.get_count()
+                if count > 0:
+                    sample = await asyncio.to_thread(
+                        self.vector_store._collection.peek,
+                        limit=1
+                    )
+                    embeddings = sample.get("embeddings")
+                    if embeddings is not None and len(embeddings) > 0:
+                        dim = len(embeddings[0])
+
+                        # 🔥 Infer the model from the dimension (prefer the most common models)
+                        dimension_mapping = {
+                            # OpenAI models
+                            1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536},
+                            3072: {"provider": "openai", "model": "text-embedding-3-large", "dimension": 3072},
+
+                            # HuggingFace models
+                            1024: {"provider": "huggingface", "model": "BAAI/bge-m3", "dimension": 1024},
+                            384: {"provider": "huggingface", "model": "sentence-transformers/all-MiniLM-L6-v2", "dimension": 384},
+
+                            # Ollama models
+                            768: {"provider": "ollama", "model": "nomic-embed-text", "dimension": 768},
+
+                            # Jina models
+                            512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512},
+
+                            # Cohere models
+                            # 1024 is already taken by HuggingFace; when Cohere shares the same dimension, HuggingFace is used by default
+                        }
+
+                        inferred = dimension_mapping.get(dim)
+                        if inferred:
+                            logger.info(f"📊 Detected vector dimension {dim}, inferred as: {inferred['provider']}/{inferred['model']}")
+                            return inferred
+        except Exception as e:
+            logger.warning(f"Unable to infer embedding config: {e}")
+
+        return None
+
+    def get_collection_embedding_config(self) -> Dict[str, Any]:
+        """
+        Get the embedding config stored on the collection
+
+        Returns:
+            A dict containing provider, model, dimension, and base_url
+        """
+        if hasattr(self.vector_store, 'get_embedding_config'):
+            return self.vector_store.get_embedding_config()
+        return {}
+
     async def retrieve(
         self,
         query: str,
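After initialize(), the stored (or empty, for legacy collections) config can be inspected through the new helper; a short sketch assuming the methods added above, with an illustrative result:

config = retriever.get_collection_embedding_config()
# e.g. {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536, "base_url": None}
# For pre-existing collections without metadata the values are None, and initialize()
# falls back to _infer_embedding_config_from_dimension().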