feat: Enhance embedding service configuration and add CI pipeline retry logic for dimension mismatches with optional index rebuild.

This commit is contained in:
vinland100 2025-12-31 17:11:27 +08:00
parent df8796e6e3
commit 4aea8ee7a9
3 changed files with 50 additions and 14 deletions

View File

@@ -205,7 +205,19 @@ class CIService:
await self._post_gitea_comment(repo, issue.get("number"), msg) await self._post_gitea_comment(repo, issue.get("number"), msg)
return return
try:
context_results = await retriever.retrieve(query, top_k=5) context_results = await retriever.retrieve(query, top_k=5)
except Exception as e:
# Check for Chroma dimension mismatch
if "dimension" in str(e).lower():
logger.warning(f"Dimension mismatch detected for project {project.id}. Rebuilding index...")
await self._ensure_indexed(project, repo, branch, force_rebuild=True)
# Retry once
context_results = await retriever.retrieve(query, top_k=5)
else:
logger.error(f"Retrieval error: {e}")
context_results = []
repo_context = "\n".join([r.to_context_string() for r in context_results]) repo_context = "\n".join([r.to_context_string() for r in context_results])
# 4. Build Prompt # 4. Build Prompt
@@ -281,7 +293,7 @@ class CIService:
return project return project
async def _ensure_indexed(self, project: Project, repo: Dict, branch: str) -> Optional[str]: async def _ensure_indexed(self, project: Project, repo: Dict, branch: str, force_rebuild: bool = False) -> Optional[str]:
""" """
Syncs the repository and ensures it is indexed. Syncs the repository and ensures it is indexed.
Returns the local path if successful. Returns the local path if successful.
@@ -295,15 +307,18 @@ class CIService:
return None return None
try: try:
# 2. Incremental Indexing # 2. Incremental or Full Indexing
indexer = CodeIndexer( indexer = CodeIndexer(
collection_name=f"ci_{project.id}", collection_name=f"ci_{project.id}",
persist_directory=str(CI_VECTOR_DB_DIR / project.id) persist_directory=str(CI_VECTOR_DB_DIR / project.id)
) )
update_mode = IndexUpdateMode.FULL if force_rebuild else IndexUpdateMode.INCREMENTAL
# Iterate over the generator to execute indexing # Iterate over the generator to execute indexing
async for progress in indexer.smart_index_directory( async for progress in indexer.smart_index_directory(
directory=repo_path, directory=repo_path,
update_mode=IndexUpdateMode.INCREMENTAL update_mode=update_mode
): ):
# Log progress occasionally # Log progress occasionally
if progress.total_files > 0 and progress.processed_files % 20 == 0: if progress.total_files > 0 and progress.processed_files % 20 == 0:

View File

@@ -60,8 +60,18 @@ class OpenAIEmbedding(EmbeddingProvider):
base_url: Optional[str] = None, base_url: Optional[str] = None,
model: str = "text-embedding-3-small", model: str = "text-embedding-3-small",
): ):
self.api_key = api_key or settings.LLM_API_KEY # 优先使用显性参数,其次使用 EMBEDDING_API_KEY,最后使用 LLM_API_KEY
self.base_url = base_url or "https://api.openai.com/v1" self.api_key = (
api_key
or getattr(settings, "EMBEDDING_API_KEY", None)
or settings.LLM_API_KEY
)
# 优先使用显性参数,其次使用 EMBEDDING_BASE_URL,最后使用 OpenAI 默认地址
self.base_url = (
base_url
or getattr(settings, "EMBEDDING_BASE_URL", None)
or "https://api.openai.com/v1"
)
self.model = model self.model = model
self._dimension = self.MODELS.get(model, 1536) self._dimension = self.MODELS.get(model, 1536)
@@ -593,15 +603,15 @@ class EmbeddingService:
# 确定提供商(保存原始值用于属性访问) # 确定提供商(保存原始值用于属性访问)
self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai') self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
self.api_key = api_key self.api_key = api_key or getattr(settings, "EMBEDDING_API_KEY", None)
self.base_url = base_url self.base_url = base_url or getattr(settings, "EMBEDDING_BASE_URL", None)
# 创建提供商实例 # 创建提供商实例
self._provider = self._create_provider( self._provider = self._create_provider(
provider=self.provider, provider=self.provider,
model=self.model, model=self.model,
api_key=api_key, api_key=self.api_key,
base_url=base_url, base_url=self.base_url,
) )
logger.info(f"Embedding service initialized with {self.provider}/{self.model}") logger.info(f"Embedding service initialized with {self.provider}/{self.model}")

View File

@@ -13,6 +13,8 @@ from .embeddings import EmbeddingService
from .indexer import VectorStore, ChromaVectorStore, InMemoryVectorStore from .indexer import VectorStore, ChromaVectorStore, InMemoryVectorStore
from .splitter import CodeChunk, ChunkType from .splitter import CodeChunk, ChunkType
from app.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -196,7 +198,19 @@ class CodeRetriever:
if embeddings is not None and len(embeddings) > 0: if embeddings is not None and len(embeddings) > 0:
dim = len(embeddings[0]) dim = len(embeddings[0])
# 🔥 根据维度推断模型(优先选择常用模型) # 🔥 1. Check if the current provider supports this dimension
current_provider_name = getattr(self.embedding_service, 'provider', settings.EMBEDDING_PROVIDER)
current_model_name = getattr(self.embedding_service, 'model', settings.EMBEDDING_MODEL)
# Use a temporary service to check dimension if needed, or just trust current settings if dimension matches
if hasattr(self.embedding_service, 'dimension') and self.embedding_service.dimension == dim:
return {
"provider": current_provider_name,
"model": current_model_name,
"dimension": dim
}
# 🔥 2. Fallback to hardcoded mapping
dimension_mapping = { dimension_mapping = {
# OpenAI 系列 # OpenAI 系列
1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536}, 1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536},
@@ -211,9 +225,6 @@ class CodeRetriever:
# Jina 系列 # Jina 系列
512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512}, 512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512},
# Cohere 系列
# 1024 已被 HuggingFace 占用Cohere 维度相同时会默认使用 HuggingFace
} }
inferred = dimension_mapping.get(dim) inferred = dimension_mapping.get(dim)