feat: Enhance RAG embedding dimension handling, add Qwen3 support, and improve index rebuild logic

vinland100 2026-01-04 11:28:58 +08:00
parent 6abede99f4
commit 23ac263d76
6 changed files with 107 additions and 27 deletions

View File

@@ -90,6 +90,7 @@ class Settings(BaseSettings):
     EMBEDDING_MODEL: str = "text-embedding-3-small"
     EMBEDDING_API_KEY: Optional[str] = None  # Dedicated API key for the embedding model; falls back to LLM_API_KEY when empty
     EMBEDDING_BASE_URL: Optional[str] = None  # Dedicated base URL for the embedding model; uses the provider default when empty
+    EMBEDDING_DIMENSION: int = 0  # Embedding dimension (0 = auto-detect, or derived from the model by the code)

     # Vector database configuration
     VECTOR_DB_PATH: str = "./data/vector_db"  # Vector database persistence directory
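A minimal sketch of how the new EMBEDDING_DIMENSION field is expected to be populated from the environment, assuming pydantic-settings style loading (consistent with the Settings(BaseSettings) class above; the project may equally be on pydantic v1). The 0 default keeps the previous behaviour of deriving the dimension from the model name:

    import os
    from pydantic_settings import BaseSettings  # assumption: pydantic v2 settings package

    class Settings(BaseSettings):
        EMBEDDING_MODEL: str = "text-embedding-3-small"
        EMBEDDING_DIMENSION: int = 0  # 0 = let the provider derive the dimension from the model

    os.environ["EMBEDDING_DIMENSION"] = "2560"  # e.g. when serving Qwen3-Embedding-4B
    print(Settings().EMBEDDING_DIMENSION)       # -> 2560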

View File

@@ -73,7 +73,7 @@ class CIService:
             return

         # 2. Sync Repository and Index
-        repo_path = await self._ensure_indexed(project, repo, branch)
+        repo_path = await self._ensure_indexed(project, repo_url, branch, pr_number=pr_number)
         if not repo_path:
             return
@@ -186,7 +186,7 @@ class CIService:
         # 2. Ensure Indexed (Important for first-time chat or if project auto-created)
         branch = repo.get("default_branch", "main")
-        repo_path = await self._ensure_indexed(project, repo, branch)
+        repo_path = await self._ensure_indexed(project, repo_url, branch, pr_number=issue.get("number"))
         if not repo_path:
             logger.error("Failed to sync/index repository for chat")
             return
@@ -302,14 +302,13 @@ class CIService:
         return project

-    async def _ensure_indexed(self, project: Project, repo: Dict, branch: str, force_rebuild: bool = False) -> Optional[str]:
+    async def _ensure_indexed(self, project: Project, repo_url: str, branch: str, pr_number: Optional[int] = None, force_rebuild: bool = False) -> Optional[str]:
         """
         Syncs the repository and ensures it is indexed.
         Returns the local path if successful.
         """
-        repo_url = repo.get("clone_url")
         # 1. Prepare Repository (Clone/Pull)
-        repo_path = await self._prepare_repository(project, repo_url, branch, settings.GITEA_BOT_TOKEN)
+        repo_path = await self._prepare_repository(project, repo_url, branch, settings.GITEA_BOT_TOKEN, pr_number=pr_number)
         if not repo_path:
             logger.error(f"Failed to prepare repository for project {project.id}")
@@ -336,15 +335,23 @@ class CIService:
             logger.info(f"✅ Project {project.name} indexing complete.")
             return repo_path
         except Exception as e:
+            err_msg = str(e)
+            # Detect dimension mismatch or specific embedding API errors that might require a rebuild
+            should_rebuild = any(x in err_msg.lower() for x in ["dimension", "404", "401", "400", "invalid_model"])
+            if not force_rebuild and should_rebuild:
+                logger.warning(f"⚠️ Indexing error for project {project.id}: {e}. Triggering automatic full rebuild...")
+                return await self._ensure_indexed(project, repo_url, branch, pr_number=pr_number, force_rebuild=True)
+
             logger.error(f"Indexing error for project {project.id}: {e}")
-            return repo_path  # Return path anyway, maybe some files are present
+            return None  # Fail properly

     async def _get_project_by_repo(self, repo_url: str) -> Optional[Project]:
         stmt = select(Project).where(Project.repository_url == repo_url)
         result = await self.db.execute(stmt)
         return result.scalars().first()

-    async def _prepare_repository(self, project: Project, repo_url: str, branch: str, token: str) -> str:
+    async def _prepare_repository(self, project: Project, repo_url: str, branch: str, token: str, pr_number: Optional[int] = None) -> str:
         """
         Clones or Updates the repository locally.
         """
@@ -381,19 +388,34 @@
             try:
                 # git fetch --all
                 subprocess.run(["git", "fetch", "--all"], cwd=target_dir, check=True)
-                # git checkout branch
-                subprocess.run(["git", "checkout", branch], cwd=target_dir, check=True)
-                # git reset --hard origin/branch
-                subprocess.run(["git", "reset", "--hard", f"origin/{branch}"], cwd=target_dir, check=True)
+
+                if pr_number:
+                    # Fetch PR ref specifically from base repo: refs/pull/ID/head
+                    logger.info(f"📥 Fetching PR ref: refs/pull/{pr_number}/head")
+                    subprocess.run(["git", "fetch", "origin", f"refs/pull/{pr_number}/head"], cwd=target_dir, check=True)
+                    subprocess.run(["git", "checkout", "FETCH_HEAD"], cwd=target_dir, check=True)
+                else:
+                    # git checkout branch
+                    subprocess.run(["git", "checkout", branch], cwd=target_dir, check=True)
+                    # git reset --hard origin/branch
+                    subprocess.run(["git", "reset", "--hard", f"origin/{branch}"], cwd=target_dir, check=True)
             except Exception as e:
                 logger.error(f"Git update failed: {e}. Re-cloning...")
                 shutil.rmtree(target_dir)  # Nuke and retry
-                return await self._prepare_repository(project, repo_url, branch, token)
+                return await self._prepare_repository(project, repo_url, branch, token, pr_number=pr_number)
         else:
             # Clone
             logger.info(f"📥 Cloning repo to {target_dir}")
             try:
-                subprocess.run(["git", "clone", "-b", branch, auth_url, str(target_dir)], check=True)
+                # Clone without -b first, then fetch and checkout
+                subprocess.run(["git", "clone", auth_url, str(target_dir)], check=True)
+
+                if pr_number:
+                    logger.info(f"📥 Fetching PR ref: refs/pull/{pr_number}/head")
+                    subprocess.run(["git", "fetch", "origin", f"refs/pull/{pr_number}/head"], cwd=target_dir, check=True)
+                    subprocess.run(["git", "checkout", "FETCH_HEAD"], cwd=target_dir, check=True)
+                else:
+                    subprocess.run(["git", "checkout", branch], cwd=target_dir, check=True)
             except Exception as e:
                 logger.error(f"Git clone failed: {e}")
                 raise e
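For reference, the PR handling above relies on Gitea exposing each pull request head as refs/pull/&lt;number&gt;/head (the same convention GitHub uses), with FETCH_HEAD pointing at the fetched ref afterwards. A minimal, hypothetical helper that captures just that step:

    import subprocess

    def checkout_pr_head(repo_dir: str, pr_number: int) -> None:
        """Fetch a pull request head from origin and check it out (detached at FETCH_HEAD)."""
        subprocess.run(["git", "fetch", "origin", f"refs/pull/{pr_number}/head"], cwd=repo_dir, check=True)
        subprocess.run(["git", "checkout", "FETCH_HEAD"], cwd=repo_dir, check=True)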
@@ -412,6 +434,9 @@ class CIService:
             resp = await client.get(api_url, headers=headers)
             if resp.status_code == 200:
                 return resp.text
+            elif resp.status_code == 403:
+                logger.error(f"❌ Failed to fetch diff: 403 Forbidden. This usually means the GITEA_BOT_TOKEN lacks 'read:repository' scope. Response: {resp.text[:200]}")
+                return ""
             else:
                 logger.error(f"Failed to fetch diff: {resp.status_code} - {resp.text[:200]}")
                 return ""

View File

@@ -52,6 +52,7 @@ class OpenAIEmbedding(EmbeddingProvider):
         "text-embedding-3-small": 1536,
         "text-embedding-3-large": 3072,
         "text-embedding-ada-002": 1536,
+        "Qwen3-Embedding-4B": 2560,
     }

     def __init__(
@@ -73,7 +74,13 @@ class OpenAIEmbedding(EmbeddingProvider):
             or "https://api.openai.com/v1"
         )
         self.model = model
-        self._dimension = self.MODELS.get(model, 1536)
+        # Prefer an explicitly configured dimension, then the per-model defaults, and finally 1536
+        self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
+        self._dimension = (
+            self._explicit_dimension
+            if self._explicit_dimension > 0
+            else self.MODELS.get(model, 1536)
+        )

     @property
     def dimension(self) -> int:
@@ -130,6 +137,7 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
         "text-embedding-3-small": 1536,
         "text-embedding-3-large": 3072,
         "text-embedding-ada-002": 1536,
+        "Qwen3-Embedding-4B": 2560,
     }

     # Latest GA API version
@@ -144,7 +152,13 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
         self.api_key = api_key
         self.base_url = base_url or "https://your-resource.openai.azure.com"
         self.model = model
-        self._dimension = self.MODELS.get(model, 1536)
+        # Prefer the explicit dimension from the project configuration
+        self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
+        self._dimension = (
+            self._explicit_dimension
+            if self._explicit_dimension > 0
+            else self.MODELS.get(model, 1536)
+        )

     @property
     def dimension(self) -> int:
@@ -481,6 +495,7 @@ class QwenEmbedding(EmbeddingProvider):
         "text-embedding-v4": 1024,  # Supported dimensions: 2048, 1536, 1024 (default), 768, 512, 256, 128, 64
         "text-embedding-v3": 1024,  # Supported dimensions: 1024 (default), 768, 512, 256, 128, 64
         "text-embedding-v2": 1536,  # Supported dimensions: 1536
+        "Qwen3-Embedding-4B": 2560,
     }

     def __init__(
@@ -505,7 +520,14 @@ class QwenEmbedding(EmbeddingProvider):
         # DashScope exposes an OpenAI-compatible embeddings endpoint
         self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
         self.model = model
-        self._dimension = self.MODELS.get(model, 1024)
+
+        # Prefer an explicitly configured dimension, then the per-model defaults, and finally 1024
+        self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
+        self._dimension = (
+            self._explicit_dimension
+            if self._explicit_dimension > 0
+            else self.MODELS.get(model, 1024)
+        )

     @property
     def dimension(self) -> int:
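The same resolution order now appears in all three providers above: an explicit EMBEDDING_DIMENSION wins, then the per-model table, then the provider default. A standalone illustration of that precedence (the function name is illustrative, not the project's API):

    MODELS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "Qwen3-Embedding-4B": 2560,
    }

    def resolve_dimension(model: str, explicit_dimension: int = 0, default: int = 1536) -> int:
        # EMBEDDING_DIMENSION > 0 always wins; 0 means "derive from the model table or fall back"
        return explicit_dimension if explicit_dimension > 0 else MODELS.get(model, default)

    assert resolve_dimension("Qwen3-Embedding-4B") == 2560                        # known model
    assert resolve_dimension("unknown-model") == 1536                             # provider default
    assert resolve_dimension("unknown-model", explicit_dimension=1024) == 1024    # explicit override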

View File

@@ -763,20 +763,22 @@ class CodeIndexer:
         Returns:
             (needs_rebuild, reason) - whether a rebuild is needed, and the reason
         """
-        if self._initialized and not force_rebuild:
-            return self._needs_rebuild, self._rebuild_reason
+        # Recreate if force_rebuild is set, or if not yet initialized while _needs_rebuild has already been flagged
+        should_recreate = force_rebuild or (not self._initialized and self._needs_rebuild)

-        # Initialize the vector_store first (no forced recreation, just attach to the existing collection)
-        await self.vector_store.initialize(force_recreate=False)
+        # Initialize the vector_store first
+        await self.vector_store.initialize(force_recreate=should_recreate)
+        if should_recreate:
+            self._needs_rebuild = True
+            self._rebuild_reason = "Forced rebuild"
+            self._initialized = True
+            return True, self._rebuild_reason

         # Check whether a rebuild is needed
         self._needs_rebuild, self._rebuild_reason = await self._check_rebuild_needed()

-        if force_rebuild:
-            self._needs_rebuild = True
-            self._rebuild_reason = "User-forced rebuild"
-
-        # If a rebuild is needed, re-initialize the vector_store (force recreation)
+        # If a rebuild was detected automatically, initialize again with forced recreation
         if self._needs_rebuild:
             logger.info(f"🔄 Index rebuild required: {self._rebuild_reason}")
             await self.vector_store.initialize(force_recreate=True)
@@ -819,11 +821,30 @@ class CodeIndexer:
         # Check the dimension
         stored_dimension = stored_config.get("dimension")
         current_dimension = self.embedding_config.get("dimension")
-        if stored_dimension and current_dimension and stored_dimension != current_dimension:
+
+        # 🔥 If the metadata does not record a dimension, try to detect it from a stored sample
+        if not stored_dimension:
+            stored_dimension = await self._detect_actual_dimension()
+            if stored_dimension:
+                logger.debug(f"🔍 Detected actual dimension from sample: {stored_dimension}")
+
+        if stored_dimension and current_dimension and int(stored_dimension) != int(current_dimension):
             return True, f"Embedding dimension changed: {stored_dimension} -> {current_dimension}"

         return False, ""

+    async def _detect_actual_dimension(self) -> Optional[int]:
+        """Detect the actual dimension from existing vectors"""
+        try:
+            if hasattr(self.vector_store, '_collection') and self.vector_store._collection:
+                peek = await asyncio.to_thread(self.vector_store._collection.peek, limit=1)
+                embeddings = peek.get("embeddings")
+                if embeddings and len(embeddings) > 0:
+                    return len(embeddings[0])
+        except Exception:
+            pass
+        return None
+
     async def get_index_status(self) -> IndexStatus:
         """Get the index status"""
         await self.initialize()
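The new _detect_actual_dimension probe peeks at a single stored vector when the collection metadata does not record a dimension. A standalone sketch of the same idea against chromadb directly (the persistence path comes from the config in this commit; the collection name is an assumption):

    import chromadb

    client = chromadb.PersistentClient(path="./data/vector_db")
    collection = client.get_or_create_collection("code_chunks")  # hypothetical collection name

    sample = collection.peek(limit=1)
    embeddings = sample.get("embeddings")
    if embeddings is not None and len(embeddings) > 0:
        print(f"Stored embedding dimension: {len(embeddings[0])}")
    else:
        print("Collection is empty; nothing to detect")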
@@ -873,7 +894,7 @@
             Indexing progress
         """
         # Initialize and check whether a rebuild is needed
-        needs_rebuild, rebuild_reason = await self.initialize()
+        needs_rebuild, rebuild_reason = await self.initialize(force_rebuild=(update_mode == IndexUpdateMode.FULL))

         progress = IndexingProgress()
         exclude_patterns = exclude_patterns or []

View File

@@ -164,6 +164,12 @@ class CodeRetriever:
         if not api_key and self._provided_embedding_service:
             api_key = getattr(self._provided_embedding_service, 'api_key', None)

+        # 🔥 Important: if the user explicitly specified a dimension and it does not match the stored one, do not auto-switch (doing so would cause errors)
+        explicit_dim = getattr(settings, "EMBEDDING_DIMENSION", 0)
+        if explicit_dim > 0 and explicit_dim != stored_dimension:
+            logger.warning(f"⚠️ Collection dimension ({stored_dimension}) does not match the explicitly configured dimension ({explicit_dim}); skipping auto-switch to avoid errors.")
+            return
+
         self.embedding_service = EmbeddingService(
             provider=stored_provider,
             model=stored_model,

View File

@@ -132,6 +132,9 @@ EMBEDDING_PROVIDER=openai
 # Ollama: nomic-embed-text, mxbai-embed-large
 EMBEDDING_MODEL=text-embedding-3-small

+# Embedding model dimension
+EMBEDDING_DIMENSION=2560
+
 # Embedding model API key; falls back to LLM_API_KEY when empty
 EMBEDDING_API_KEY=
@@ -219,5 +222,7 @@ ZIP_STORAGE_PATH=./uploads/zip_files
 # zh-CN: Chinese
 # en-US: English
 OUTPUT_LANGUAGE=zh-CN
+
+# Gitea configuration
 GITEA_HOST_URL=
 GITEA_BOT_TOKEN=