feat: Enhance embedding service configuration and add CI pipeline retry logic for dimension mismatches with optional index rebuild.

This commit is contained in:
vinland100 2025-12-31 17:11:27 +08:00
parent df8796e6e3
commit 4aea8ee7a9
3 changed files with 50 additions and 14 deletions

View File

@@ -205,7 +205,19 @@ class CIService:
await self._post_gitea_comment(repo, issue.get("number"), msg)
return
context_results = await retriever.retrieve(query, top_k=5)
try:
context_results = await retriever.retrieve(query, top_k=5)
except Exception as e:
# Check for Chroma dimension mismatch
if "dimension" in str(e).lower():
logger.warning(f"Dimension mismatch detected for project {project.id}. Rebuilding index...")
await self._ensure_indexed(project, repo, branch, force_rebuild=True)
# Retry once
context_results = await retriever.retrieve(query, top_k=5)
else:
logger.error(f"Retrieval error: {e}")
context_results = []
repo_context = "\n".join([r.to_context_string() for r in context_results])
# 4. Build Prompt
@@ -281,7 +293,7 @@ class CIService:
return project
async def _ensure_indexed(self, project: Project, repo: Dict, branch: str) -> Optional[str]:
async def _ensure_indexed(self, project: Project, repo: Dict, branch: str, force_rebuild: bool = False) -> Optional[str]:
"""
Syncs the repository and ensures it is indexed.
Returns the local path if successful.
@@ -295,15 +307,18 @@ class CIService:
return None
try:
# 2. Incremental Indexing
# 2. Incremental or Full Indexing
indexer = CodeIndexer(
collection_name=f"ci_{project.id}",
persist_directory=str(CI_VECTOR_DB_DIR / project.id)
)
update_mode = IndexUpdateMode.FULL if force_rebuild else IndexUpdateMode.INCREMENTAL
# Iterate over the generator to execute indexing
async for progress in indexer.smart_index_directory(
directory=repo_path,
update_mode=IndexUpdateMode.INCREMENTAL
update_mode=update_mode
):
# Log progress occasionally
if progress.total_files > 0 and progress.processed_files % 20 == 0:

View File

@@ -60,8 +60,18 @@ class OpenAIEmbedding(EmbeddingProvider):
base_url: Optional[str] = None,
model: str = "text-embedding-3-small",
):
self.api_key = api_key or settings.LLM_API_KEY
self.base_url = base_url or "https://api.openai.com/v1"
# 优先使用显性参数,其次使用 EMBEDDING_API_KEY,最后使用 LLM_API_KEY
self.api_key = (
api_key
or getattr(settings, "EMBEDDING_API_KEY", None)
or settings.LLM_API_KEY
)
# 优先使用显性参数,其次使用 EMBEDDING_BASE_URL,最后使用 OpenAI 默认地址
self.base_url = (
base_url
or getattr(settings, "EMBEDDING_BASE_URL", None)
or "https://api.openai.com/v1"
)
self.model = model
self._dimension = self.MODELS.get(model, 1536)
@@ -593,15 +603,15 @@ class EmbeddingService:
# 确定提供商(保存原始值用于属性访问)
self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
self.api_key = api_key
self.base_url = base_url
self.api_key = api_key or getattr(settings, "EMBEDDING_API_KEY", None)
self.base_url = base_url or getattr(settings, "EMBEDDING_BASE_URL", None)
# 创建提供商实例
self._provider = self._create_provider(
provider=self.provider,
model=self.model,
api_key=api_key,
base_url=base_url,
api_key=self.api_key,
base_url=self.base_url,
)
logger.info(f"Embedding service initialized with {self.provider}/{self.model}")

View File

@@ -13,6 +13,8 @@ from .embeddings import EmbeddingService
from .indexer import VectorStore, ChromaVectorStore, InMemoryVectorStore
from .splitter import CodeChunk, ChunkType
from app.core.config import settings
logger = logging.getLogger(__name__)
@@ -196,7 +198,19 @@ class CodeRetriever:
if embeddings is not None and len(embeddings) > 0:
dim = len(embeddings[0])
# 🔥 根据维度推断模型(优先选择常用模型)
# 🔥 1. Check if the current provider supports this dimension
current_provider_name = getattr(self.embedding_service, 'provider', settings.EMBEDDING_PROVIDER)
current_model_name = getattr(self.embedding_service, 'model', settings.EMBEDDING_MODEL)
# Use a temporary service to check dimension if needed, or just trust current settings if dimension matches
if hasattr(self.embedding_service, 'dimension') and self.embedding_service.dimension == dim:
return {
"provider": current_provider_name,
"model": current_model_name,
"dimension": dim
}
# 🔥 2. Fallback to hardcoded mapping
dimension_mapping = {
# OpenAI 系列
1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536},
@@ -211,9 +225,6 @@ class CodeRetriever:
# Jina 系列
512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512},
# Cohere 系列
# 1024 已被 HuggingFace 占用,Cohere 维度相同时会默认使用 HuggingFace
}
inferred = dimension_mapping.get(dim)