feat: Support dedicated embedding API key/base URL settings (EMBEDDING_API_KEY, EMBEDDING_BASE_URL), and retry CI context retrieval once after a forced full index rebuild when an embedding dimension mismatch is detected.
This commit is contained in:
parent
df8796e6e3
commit
4aea8ee7a9
|
|
@ -205,7 +205,19 @@ class CIService:
|
|||
await self._post_gitea_comment(repo, issue.get("number"), msg)
|
||||
return
|
||||
|
||||
try:
|
||||
context_results = await retriever.retrieve(query, top_k=5)
|
||||
except Exception as e:
|
||||
# Check for Chroma dimension mismatch
|
||||
if "dimension" in str(e).lower():
|
||||
logger.warning(f"Dimension mismatch detected for project {project.id}. Rebuilding index...")
|
||||
await self._ensure_indexed(project, repo, branch, force_rebuild=True)
|
||||
# Retry once
|
||||
context_results = await retriever.retrieve(query, top_k=5)
|
||||
else:
|
||||
logger.error(f"Retrieval error: {e}")
|
||||
context_results = []
|
||||
|
||||
repo_context = "\n".join([r.to_context_string() for r in context_results])
|
||||
|
||||
# 4. Build Prompt
|
||||
|
|
@ -281,7 +293,7 @@ class CIService:
|
|||
|
||||
return project
|
||||
|
||||
async def _ensure_indexed(self, project: Project, repo: Dict, branch: str) -> Optional[str]:
|
||||
async def _ensure_indexed(self, project: Project, repo: Dict, branch: str, force_rebuild: bool = False) -> Optional[str]:
|
||||
"""
|
||||
Syncs the repository and ensures it is indexed.
|
||||
Returns the local path if successful.
|
||||
|
|
@ -295,15 +307,18 @@ class CIService:
|
|||
return None
|
||||
|
||||
try:
|
||||
# 2. Incremental Indexing
|
||||
# 2. Incremental or Full Indexing
|
||||
indexer = CodeIndexer(
|
||||
collection_name=f"ci_{project.id}",
|
||||
persist_directory=str(CI_VECTOR_DB_DIR / project.id)
|
||||
)
|
||||
|
||||
update_mode = IndexUpdateMode.FULL if force_rebuild else IndexUpdateMode.INCREMENTAL
|
||||
|
||||
# Iterate over the generator to execute indexing
|
||||
async for progress in indexer.smart_index_directory(
|
||||
directory=repo_path,
|
||||
update_mode=IndexUpdateMode.INCREMENTAL
|
||||
update_mode=update_mode
|
||||
):
|
||||
# Log progress occasionally
|
||||
if progress.total_files > 0 and progress.processed_files % 20 == 0:
|
||||
|
|
|
|||
|
|
@ -60,8 +60,18 @@ class OpenAIEmbedding(EmbeddingProvider):
|
|||
base_url: Optional[str] = None,
|
||||
model: str = "text-embedding-3-small",
|
||||
):
|
||||
self.api_key = api_key or settings.LLM_API_KEY
|
||||
self.base_url = base_url or "https://api.openai.com/v1"
|
||||
# 优先使用显性参数,其次使用 EMBEDDING_API_KEY,最后使用 LLM_API_KEY
|
||||
self.api_key = (
|
||||
api_key
|
||||
or getattr(settings, "EMBEDDING_API_KEY", None)
|
||||
or settings.LLM_API_KEY
|
||||
)
|
||||
# 优先使用显性参数,其次使用 EMBEDDING_BASE_URL,最后使用 OpenAI 默认地址
|
||||
self.base_url = (
|
||||
base_url
|
||||
or getattr(settings, "EMBEDDING_BASE_URL", None)
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
self.model = model
|
||||
self._dimension = self.MODELS.get(model, 1536)
|
||||
|
||||
|
|
@ -593,15 +603,15 @@ class EmbeddingService:
|
|||
# 确定提供商(保存原始值用于属性访问)
|
||||
self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
|
||||
self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key or getattr(settings, "EMBEDDING_API_KEY", None)
|
||||
self.base_url = base_url or getattr(settings, "EMBEDDING_BASE_URL", None)
|
||||
|
||||
# 创建提供商实例
|
||||
self._provider = self._create_provider(
|
||||
provider=self.provider,
|
||||
model=self.model,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
)
|
||||
|
||||
logger.info(f"Embedding service initialized with {self.provider}/{self.model}")
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from .embeddings import EmbeddingService
|
|||
from .indexer import VectorStore, ChromaVectorStore, InMemoryVectorStore
|
||||
from .splitter import CodeChunk, ChunkType
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -196,7 +198,19 @@ class CodeRetriever:
|
|||
if embeddings is not None and len(embeddings) > 0:
|
||||
dim = len(embeddings[0])
|
||||
|
||||
# 🔥 根据维度推断模型(优先选择常用模型)
|
||||
# 🔥 1. Check if the current provider supports this dimension
|
||||
current_provider_name = getattr(self.embedding_service, 'provider', settings.EMBEDDING_PROVIDER)
|
||||
current_model_name = getattr(self.embedding_service, 'model', settings.EMBEDDING_MODEL)
|
||||
|
||||
# Use a temporary service to check dimension if needed, or just trust current settings if dimension matches
|
||||
if hasattr(self.embedding_service, 'dimension') and self.embedding_service.dimension == dim:
|
||||
return {
|
||||
"provider": current_provider_name,
|
||||
"model": current_model_name,
|
||||
"dimension": dim
|
||||
}
|
||||
|
||||
# 🔥 2. Fallback to hardcoded mapping
|
||||
dimension_mapping = {
|
||||
# OpenAI 系列
|
||||
1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536},
|
||||
|
|
@ -211,9 +225,6 @@ class CodeRetriever:
|
|||
|
||||
# Jina 系列
|
||||
512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512},
|
||||
|
||||
# Cohere 系列
|
||||
# 1024 已被 HuggingFace 占用,Cohere 维度相同时会默认使用 HuggingFace
|
||||
}
|
||||
|
||||
inferred = dimension_mapping.get(dim)
|
||||
|
|
|
|||
Loading…
Reference in New Issue