feat: Enhance RAG embedding dimension handling, add Qwen3 support, and improve index rebuild logic

This commit is contained in:
vinland100 2026-01-04 11:28:58 +08:00
parent 6abede99f4
commit 23ac263d76
6 changed files with 107 additions and 27 deletions

View File

@ -90,6 +90,7 @@ class Settings(BaseSettings):
EMBEDDING_MODEL: str = "text-embedding-3-small"
EMBEDDING_API_KEY: Optional[str] = None # 嵌入模型专用 API Key留空则使用 LLM_API_KEY
EMBEDDING_BASE_URL: Optional[str] = None # 嵌入模型专用 Base URL留空使用提供商默认地址
EMBEDDING_DIMENSION: int = 0 # 嵌入模型维度0 表示自动检测或由代码逻辑根据模型确定)
# 向量数据库配置
VECTOR_DB_PATH: str = "./data/vector_db" # 向量数据库持久化目录

View File

@ -73,7 +73,7 @@ class CIService:
return
# 2. Sync Repository and Index
repo_path = await self._ensure_indexed(project, repo, branch)
repo_path = await self._ensure_indexed(project, repo_url, branch, pr_number=pr_number)
if not repo_path:
return
@ -186,7 +186,7 @@ class CIService:
# 2. Ensure Indexed (Important for first-time chat or if project auto-created)
branch = repo.get("default_branch", "main")
repo_path = await self._ensure_indexed(project, repo, branch)
repo_path = await self._ensure_indexed(project, repo_url, branch, pr_number=issue.get("number"))
if not repo_path:
logger.error("Failed to sync/index repository for chat")
return
@ -302,14 +302,13 @@ class CIService:
return project
async def _ensure_indexed(self, project: Project, repo: Dict, branch: str, force_rebuild: bool = False) -> Optional[str]:
async def _ensure_indexed(self, project: Project, repo_url: str, branch: str, pr_number: Optional[int] = None, force_rebuild: bool = False) -> Optional[str]:
"""
Syncs the repository and ensures it is indexed.
Returns the local path if successful.
"""
repo_url = repo.get("clone_url")
# 1. Prepare Repository (Clone/Pull)
repo_path = await self._prepare_repository(project, repo_url, branch, settings.GITEA_BOT_TOKEN)
repo_path = await self._prepare_repository(project, repo_url, branch, settings.GITEA_BOT_TOKEN, pr_number=pr_number)
if not repo_path:
logger.error(f"Failed to prepare repository for project {project.id}")
@ -336,15 +335,23 @@ class CIService:
logger.info(f"✅ Project {project.name} indexing complete.")
return repo_path
except Exception as e:
err_msg = str(e)
# Detect dimension mismatch or specific embedding API errors that might require a rebuild
should_rebuild = any(x in err_msg.lower() for x in ["dimension", "404", "401", "400", "invalid_model"])
if not force_rebuild and should_rebuild:
logger.warning(f"⚠️ Indexing error for project {project.id}: {e}. Triggering automatic full rebuild...")
return await self._ensure_indexed(project, repo_url, branch, pr_number=pr_number, force_rebuild=True)
logger.error(f"Indexing error for project {project.id}: {e}")
return repo_path # Return path anyway, maybe some files are present
return None # Fail properly
async def _get_project_by_repo(self, repo_url: str) -> Optional[Project]:
stmt = select(Project).where(Project.repository_url == repo_url)
result = await self.db.execute(stmt)
return result.scalars().first()
async def _prepare_repository(self, project: Project, repo_url: str, branch: str, token: str) -> str:
async def _prepare_repository(self, project: Project, repo_url: str, branch: str, token: str, pr_number: Optional[int] = None) -> str:
"""
Clones or Updates the repository locally.
"""
@ -381,6 +388,13 @@ class CIService:
try:
# git fetch --all
subprocess.run(["git", "fetch", "--all"], cwd=target_dir, check=True)
if pr_number:
# Fetch PR ref specifically from base repo: refs/pull/ID/head
logger.info(f"📥 Fetching PR ref: refs/pull/{pr_number}/head")
subprocess.run(["git", "fetch", "origin", f"refs/pull/{pr_number}/head"], cwd=target_dir, check=True)
subprocess.run(["git", "checkout", "FETCH_HEAD"], cwd=target_dir, check=True)
else:
# git checkout branch
subprocess.run(["git", "checkout", branch], cwd=target_dir, check=True)
# git reset --hard origin/branch
@ -388,12 +402,20 @@ class CIService:
except Exception as e:
logger.error(f"Git update failed: {e}. Re-cloning...")
shutil.rmtree(target_dir) # Nuke and retry
return await self._prepare_repository(project, repo_url, branch, token)
return await self._prepare_repository(project, repo_url, branch, token, pr_number=pr_number)
else:
# Clone
logger.info(f"📥 Cloning repo to {target_dir}")
try:
subprocess.run(["git", "clone", "-b", branch, auth_url, str(target_dir)], check=True)
# Clone without -b first, then fetch and checkout
subprocess.run(["git", "clone", auth_url, str(target_dir)], check=True)
if pr_number:
logger.info(f"📥 Fetching PR ref: refs/pull/{pr_number}/head")
subprocess.run(["git", "fetch", "origin", f"refs/pull/{pr_number}/head"], cwd=target_dir, check=True)
subprocess.run(["git", "checkout", "FETCH_HEAD"], cwd=target_dir, check=True)
else:
subprocess.run(["git", "checkout", branch], cwd=target_dir, check=True)
except Exception as e:
logger.error(f"Git clone failed: {e}")
raise e
@ -412,6 +434,9 @@ class CIService:
resp = await client.get(api_url, headers=headers)
if resp.status_code == 200:
return resp.text
elif resp.status_code == 403:
logger.error(f"❌ Failed to fetch diff: 403 Forbidden. This usually means the GITEA_BOT_TOKEN lacks 'read:repository' scope. Response: {resp.text[:200]}")
return ""
else:
logger.error(f"Failed to fetch diff: {resp.status_code} - {resp.text[:200]}")
return ""

View File

@ -52,6 +52,7 @@ class OpenAIEmbedding(EmbeddingProvider):
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536,
"Qwen3-Embedding-4B": 2560,
}
def __init__(
@ -73,7 +74,13 @@ class OpenAIEmbedding(EmbeddingProvider):
or "https://api.openai.com/v1"
)
self.model = model
self._dimension = self.MODELS.get(model, 1536)
# 优先使用显式指定的维度,其次根据模型预定义,最后默认 1536
self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
self._dimension = (
self._explicit_dimension
if self._explicit_dimension > 0
else self.MODELS.get(model, 1536)
)
@property
def dimension(self) -> int:
@ -130,6 +137,7 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536,
"Qwen3-Embedding-4B": 2560,
}
# 最新的 GA API 版本
@ -144,7 +152,13 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
self.api_key = api_key
self.base_url = base_url or "https://your-resource.openai.azure.com"
self.model = model
self._dimension = self.MODELS.get(model, 1536)
# 优先使用项目配置中的显式维度
self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
self._dimension = (
self._explicit_dimension
if self._explicit_dimension > 0
else self.MODELS.get(model, 1536)
)
@property
def dimension(self) -> int:
@ -481,6 +495,7 @@ class QwenEmbedding(EmbeddingProvider):
"text-embedding-v4": 1024, # 支持维度: 2048, 1536, 1024(默认), 768, 512, 256, 128, 64
"text-embedding-v3": 1024, # 支持维度: 1024(默认), 768, 512, 256, 128, 64
"text-embedding-v2": 1536, # 支持维度: 1536
"Qwen3-Embedding-4B": 2560,
}
def __init__(
@ -505,7 +520,14 @@ class QwenEmbedding(EmbeddingProvider):
# DashScope 兼容 OpenAI 的 embeddings 端点
self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
self.model = model
self._dimension = self.MODELS.get(model, 1024)
# 优先使用显式指定的维度,其次根据模型预定义,最后默认 1024
self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
self._dimension = (
self._explicit_dimension
if self._explicit_dimension > 0
else self.MODELS.get(model, 1024)
)
@property
def dimension(self) -> int:

View File

@ -763,20 +763,22 @@ class CodeIndexer:
Returns:
(needs_rebuild, reason) - 是否需要重建及原因
"""
if self._initialized and not force_rebuild:
return self._needs_rebuild, self._rebuild_reason
# 如果 force_rebuild 为真,或者尚未初始化且 _needs_rebuild 已通过某种方式置为真
should_recreate = force_rebuild or (not self._initialized and self._needs_rebuild)
# 先初始化 vector_store不强制重建只是获取现有 collection
await self.vector_store.initialize(force_recreate=False)
# 先初始化 vector_store
await self.vector_store.initialize(force_recreate=should_recreate)
if should_recreate:
self._needs_rebuild = True
self._rebuild_reason = "强制重建"
self._initialized = True
return True, self._rebuild_reason
# 检查是否需要重建
self._needs_rebuild, self._rebuild_reason = await self._check_rebuild_needed()
if force_rebuild:
self._needs_rebuild = True
self._rebuild_reason = "用户强制重建"
# 如果需要重建,重新初始化 vector_store强制重建
# 如果自动检测到需要重建,则再次初始化并强制重建
if self._needs_rebuild:
logger.info(f"🔄 需要重建索引: {self._rebuild_reason}")
await self.vector_store.initialize(force_recreate=True)
@ -819,11 +821,30 @@ class CodeIndexer:
# 检查维度
stored_dimension = stored_config.get("dimension")
current_dimension = self.embedding_config.get("dimension")
if stored_dimension and current_dimension and stored_dimension != current_dimension:
# 🔥 如果 metadata 中没有维度,尝试从 sample 中检测
if not stored_dimension:
stored_dimension = await self._detect_actual_dimension()
if stored_dimension:
logger.debug(f"🔍 从 sample 检测到实际维度: {stored_dimension}")
if stored_dimension and current_dimension and int(stored_dimension) != int(current_dimension):
return True, f"Embedding 维度变更: {stored_dimension} -> {current_dimension}"
return False, ""
async def _detect_actual_dimension(self) -> Optional[int]:
"""从既存向量中检测实际维度"""
try:
if hasattr(self.vector_store, '_collection') and self.vector_store._collection:
peek = await asyncio.to_thread(self.vector_store._collection.peek, limit=1)
embeddings = peek.get("embeddings")
if embeddings and len(embeddings) > 0:
return len(embeddings[0])
except Exception:
pass
return None
async def get_index_status(self) -> IndexStatus:
"""获取索引状态"""
await self.initialize()
@ -873,7 +894,7 @@ class CodeIndexer:
索引进度
"""
# 初始化并检查是否需要重建
needs_rebuild, rebuild_reason = await self.initialize()
needs_rebuild, rebuild_reason = await self.initialize(force_rebuild=(update_mode == IndexUpdateMode.FULL))
progress = IndexingProgress()
exclude_patterns = exclude_patterns or []

View File

@ -164,6 +164,12 @@ class CodeRetriever:
if not api_key and self._provided_embedding_service:
api_key = getattr(self._provided_embedding_service, 'api_key', None)
# 🔥 重要:如果用户显式指定了维度,且与存储的维度不匹配,则不应自动切换(会导致报错)
explicit_dim = getattr(settings, "EMBEDDING_DIMENSION", 0)
if explicit_dim > 0 and explicit_dim != stored_dimension:
logger.warning(f"⚠️ Collection 维度 ({stored_dimension}) 与显式指定的维度 ({explicit_dim}) 不匹配,跳过自动切换以避免错误。")
return
self.embedding_service = EmbeddingService(
provider=stored_provider,
model=stored_model,

View File

@ -132,6 +132,9 @@ EMBEDDING_PROVIDER=openai
# Ollama: nomic-embed-text, mxbai-embed-large
EMBEDDING_MODEL=text-embedding-3-small
# 嵌入模型维度
EMBEDDING_DIMENSION=2560
# 嵌入模型 API Key留空则使用 LLM_API_KEY
EMBEDDING_API_KEY=
@ -219,5 +222,7 @@ ZIP_STORAGE_PATH=./uploads/zip_files
# zh-CN: 中文
# en-US: 英文
OUTPUT_LANGUAGE=zh-CN
# Gitea 配置
GITEA_HOST_URL=
GITEA_BOT_TOKEN=