feat: Enhance embedding service configuration and add CI pipeline retry logic for dimension mismatches with optional index rebuild.
This commit is contained in:
parent
df8796e6e3
commit
4aea8ee7a9
|
|
@ -205,7 +205,19 @@ class CIService:
|
||||||
await self._post_gitea_comment(repo, issue.get("number"), msg)
|
await self._post_gitea_comment(repo, issue.get("number"), msg)
|
||||||
return
|
return
|
||||||
|
|
||||||
context_results = await retriever.retrieve(query, top_k=5)
|
try:
|
||||||
|
context_results = await retriever.retrieve(query, top_k=5)
|
||||||
|
except Exception as e:
|
||||||
|
# Heuristic: treat any error whose message mentions "dimension" as a Chroma embedding-dimension mismatch
|
||||||
|
if "dimension" in str(e).lower():
|
||||||
|
logger.warning(f"Dimension mismatch detected for project {project.id}. Rebuilding index...")
|
||||||
|
await self._ensure_indexed(project, repo, branch, force_rebuild=True)
|
||||||
|
# Retry once after the rebuild; an error raised by this retry is not caught by this handler
|
||||||
|
context_results = await retriever.retrieve(query, top_k=5)
|
||||||
|
else:
|
||||||
|
logger.error(f"Retrieval error: {e}")
|
||||||
|
context_results = []
|
||||||
|
|
||||||
repo_context = "\n".join([r.to_context_string() for r in context_results])
|
repo_context = "\n".join([r.to_context_string() for r in context_results])
|
||||||
|
|
||||||
# 4. Build Prompt
|
# 4. Build Prompt
|
||||||
|
|
@ -281,7 +293,7 @@ class CIService:
|
||||||
|
|
||||||
return project
|
return project
|
||||||
|
|
||||||
async def _ensure_indexed(self, project: Project, repo: Dict, branch: str) -> Optional[str]:
|
async def _ensure_indexed(self, project: Project, repo: Dict, branch: str, force_rebuild: bool = False) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Syncs the repository and ensures it is indexed.
|
Syncs the repository and ensures it is indexed.
|
||||||
Returns the local path if successful.
|
Returns the local path if successful.
|
||||||
|
|
@ -295,15 +307,18 @@ class CIService:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 2. Incremental Indexing
|
# 2. Incremental or Full Indexing
|
||||||
indexer = CodeIndexer(
|
indexer = CodeIndexer(
|
||||||
collection_name=f"ci_{project.id}",
|
collection_name=f"ci_{project.id}",
|
||||||
persist_directory=str(CI_VECTOR_DB_DIR / project.id)
|
persist_directory=str(CI_VECTOR_DB_DIR / project.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
update_mode = IndexUpdateMode.FULL if force_rebuild else IndexUpdateMode.INCREMENTAL
|
||||||
|
|
||||||
# Iterate over the generator to execute indexing
|
# Iterate over the generator to execute indexing
|
||||||
async for progress in indexer.smart_index_directory(
|
async for progress in indexer.smart_index_directory(
|
||||||
directory=repo_path,
|
directory=repo_path,
|
||||||
update_mode=IndexUpdateMode.INCREMENTAL
|
update_mode=update_mode
|
||||||
):
|
):
|
||||||
# Log progress occasionally
|
# Log progress occasionally
|
||||||
if progress.total_files > 0 and progress.processed_files % 20 == 0:
|
if progress.total_files > 0 and progress.processed_files % 20 == 0:
|
||||||
|
|
|
||||||
|
|
@ -60,8 +60,18 @@ class OpenAIEmbedding(EmbeddingProvider):
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
model: str = "text-embedding-3-small",
|
model: str = "text-embedding-3-small",
|
||||||
):
|
):
|
||||||
self.api_key = api_key or settings.LLM_API_KEY
|
# Prefer the explicit argument, then EMBEDDING_API_KEY, and finally fall back to LLM_API_KEY
|
||||||
self.base_url = base_url or "https://api.openai.com/v1"
|
self.api_key = (
|
||||||
|
api_key
|
||||||
|
or getattr(settings, "EMBEDDING_API_KEY", None)
|
||||||
|
or settings.LLM_API_KEY
|
||||||
|
)
|
||||||
|
# Prefer the explicit argument, then EMBEDDING_BASE_URL, and finally fall back to the default OpenAI endpoint
|
||||||
|
self.base_url = (
|
||||||
|
base_url
|
||||||
|
or getattr(settings, "EMBEDDING_BASE_URL", None)
|
||||||
|
or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
self.model = model
|
self.model = model
|
||||||
self._dimension = self.MODELS.get(model, 1536)
|
self._dimension = self.MODELS.get(model, 1536)
|
||||||
|
|
||||||
|
|
@ -593,15 +603,15 @@ class EmbeddingService:
|
||||||
# 确定提供商(保存原始值用于属性访问)
|
# 确定提供商(保存原始值用于属性访问)
|
||||||
self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
|
self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
|
||||||
self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
|
self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
|
||||||
self.api_key = api_key
|
self.api_key = api_key or getattr(settings, "EMBEDDING_API_KEY", None)
|
||||||
self.base_url = base_url
|
self.base_url = base_url or getattr(settings, "EMBEDDING_BASE_URL", None)
|
||||||
|
|
||||||
# 创建提供商实例
|
# 创建提供商实例
|
||||||
self._provider = self._create_provider(
|
self._provider = self._create_provider(
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
api_key=api_key,
|
api_key=self.api_key,
|
||||||
base_url=base_url,
|
base_url=self.base_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Embedding service initialized with {self.provider}/{self.model}")
|
logger.info(f"Embedding service initialized with {self.provider}/{self.model}")
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@ from .embeddings import EmbeddingService
|
||||||
from .indexer import VectorStore, ChromaVectorStore, InMemoryVectorStore
|
from .indexer import VectorStore, ChromaVectorStore, InMemoryVectorStore
|
||||||
from .splitter import CodeChunk, ChunkType
|
from .splitter import CodeChunk, ChunkType
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -196,7 +198,19 @@ class CodeRetriever:
|
||||||
if embeddings is not None and len(embeddings) > 0:
|
if embeddings is not None and len(embeddings) > 0:
|
||||||
dim = len(embeddings[0])
|
dim = len(embeddings[0])
|
||||||
|
|
||||||
# 🔥 根据维度推断模型(优先选择常用模型)
|
# 🔥 1. Check if the current provider supports this dimension
|
||||||
|
current_provider_name = getattr(self.embedding_service, 'provider', settings.EMBEDDING_PROVIDER)
|
||||||
|
current_model_name = getattr(self.embedding_service, 'model', settings.EMBEDDING_MODEL)
|
||||||
|
|
||||||
|
# If the current embedding service reports the same dimension as the stored vectors, trust its configured provider/model
|
||||||
|
if hasattr(self.embedding_service, 'dimension') and self.embedding_service.dimension == dim:
|
||||||
|
return {
|
||||||
|
"provider": current_provider_name,
|
||||||
|
"model": current_model_name,
|
||||||
|
"dimension": dim
|
||||||
|
}
|
||||||
|
|
||||||
|
# 🔥 2. Fallback to hardcoded mapping
|
||||||
dimension_mapping = {
|
dimension_mapping = {
|
||||||
# OpenAI 系列
|
# OpenAI 系列
|
||||||
1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536},
|
1536: {"provider": "openai", "model": "text-embedding-3-small", "dimension": 1536},
|
||||||
|
|
@ -211,9 +225,6 @@ class CodeRetriever:
|
||||||
|
|
||||||
# Jina 系列
|
# Jina 系列
|
||||||
512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512},
|
512: {"provider": "jina", "model": "jina-embeddings-v2-small-en", "dimension": 512},
|
||||||
|
|
||||||
# Cohere 系列
|
|
||||||
# 1024 已被 HuggingFace 占用,Cohere 维度相同时会默认使用 HuggingFace
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inferred = dimension_mapping.get(dim)
|
inferred = dimension_mapping.get(dim)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue