feat: Enhance RAG embedding dimension handling, add Qwen3 support, and improve index rebuild logic
This commit is contained in:
parent
6abede99f4
commit
23ac263d76
|
|
@ -90,6 +90,7 @@ class Settings(BaseSettings):
|
|||
EMBEDDING_MODEL: str = "text-embedding-3-small"
|
||||
EMBEDDING_API_KEY: Optional[str] = None # 嵌入模型专用 API Key(留空则使用 LLM_API_KEY)
|
||||
EMBEDDING_BASE_URL: Optional[str] = None # 嵌入模型专用 Base URL(留空使用提供商默认地址)
|
||||
EMBEDDING_DIMENSION: int = 0 # 嵌入模型维度(0 表示自动检测或由代码逻辑根据模型确定)
|
||||
|
||||
# 向量数据库配置
|
||||
VECTOR_DB_PATH: str = "./data/vector_db" # 向量数据库持久化目录
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class CIService:
|
|||
return
|
||||
|
||||
# 2. Sync Repository and Index
|
||||
repo_path = await self._ensure_indexed(project, repo, branch)
|
||||
repo_path = await self._ensure_indexed(project, repo_url, branch, pr_number=pr_number)
|
||||
if not repo_path:
|
||||
return
|
||||
|
||||
|
|
@ -186,7 +186,7 @@ class CIService:
|
|||
|
||||
# 2. Ensure Indexed (Important for first-time chat or if project auto-created)
|
||||
branch = repo.get("default_branch", "main")
|
||||
repo_path = await self._ensure_indexed(project, repo, branch)
|
||||
repo_path = await self._ensure_indexed(project, repo_url, branch, pr_number=issue.get("number"))
|
||||
if not repo_path:
|
||||
logger.error("Failed to sync/index repository for chat")
|
||||
return
|
||||
|
|
@ -302,14 +302,13 @@ class CIService:
|
|||
|
||||
return project
|
||||
|
||||
async def _ensure_indexed(self, project: Project, repo: Dict, branch: str, force_rebuild: bool = False) -> Optional[str]:
|
||||
async def _ensure_indexed(self, project: Project, repo_url: str, branch: str, pr_number: Optional[int] = None, force_rebuild: bool = False) -> Optional[str]:
|
||||
"""
|
||||
Syncs the repository and ensures it is indexed.
|
||||
Returns the local path if successful.
|
||||
"""
|
||||
repo_url = repo.get("clone_url")
|
||||
# 1. Prepare Repository (Clone/Pull)
|
||||
repo_path = await self._prepare_repository(project, repo_url, branch, settings.GITEA_BOT_TOKEN)
|
||||
repo_path = await self._prepare_repository(project, repo_url, branch, settings.GITEA_BOT_TOKEN, pr_number=pr_number)
|
||||
|
||||
if not repo_path:
|
||||
logger.error(f"Failed to prepare repository for project {project.id}")
|
||||
|
|
@ -336,15 +335,23 @@ class CIService:
|
|||
logger.info(f"✅ Project {project.name} indexing complete.")
|
||||
return repo_path
|
||||
except Exception as e:
|
||||
err_msg = str(e)
|
||||
# Detect dimension mismatch or specific embedding API errors that might require a rebuild
|
||||
should_rebuild = any(x in err_msg.lower() for x in ["dimension", "404", "401", "400", "invalid_model"])
|
||||
|
||||
if not force_rebuild and should_rebuild:
|
||||
logger.warning(f"⚠️ Indexing error for project {project.id}: {e}. Triggering automatic full rebuild...")
|
||||
return await self._ensure_indexed(project, repo_url, branch, pr_number=pr_number, force_rebuild=True)
|
||||
|
||||
logger.error(f"Indexing error for project {project.id}: {e}")
|
||||
return repo_path # Return path anyway, maybe some files are present
|
||||
return None # Fail properly
|
||||
|
||||
async def _get_project_by_repo(self, repo_url: str) -> Optional[Project]:
|
||||
stmt = select(Project).where(Project.repository_url == repo_url)
|
||||
result = await self.db.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def _prepare_repository(self, project: Project, repo_url: str, branch: str, token: str) -> str:
|
||||
async def _prepare_repository(self, project: Project, repo_url: str, branch: str, token: str, pr_number: Optional[int] = None) -> str:
|
||||
"""
|
||||
Clones or Updates the repository locally.
|
||||
"""
|
||||
|
|
@ -381,6 +388,13 @@ class CIService:
|
|||
try:
|
||||
# git fetch --all
|
||||
subprocess.run(["git", "fetch", "--all"], cwd=target_dir, check=True)
|
||||
|
||||
if pr_number:
|
||||
# Fetch PR ref specifically from base repo: refs/pull/ID/head
|
||||
logger.info(f"📥 Fetching PR ref: refs/pull/{pr_number}/head")
|
||||
subprocess.run(["git", "fetch", "origin", f"refs/pull/{pr_number}/head"], cwd=target_dir, check=True)
|
||||
subprocess.run(["git", "checkout", "FETCH_HEAD"], cwd=target_dir, check=True)
|
||||
else:
|
||||
# git checkout branch
|
||||
subprocess.run(["git", "checkout", branch], cwd=target_dir, check=True)
|
||||
# git reset --hard origin/branch
|
||||
|
|
@ -388,12 +402,20 @@ class CIService:
|
|||
except Exception as e:
|
||||
logger.error(f"Git update failed: {e}. Re-cloning...")
|
||||
shutil.rmtree(target_dir) # Nuke and retry
|
||||
return await self._prepare_repository(project, repo_url, branch, token)
|
||||
return await self._prepare_repository(project, repo_url, branch, token, pr_number=pr_number)
|
||||
else:
|
||||
# Clone
|
||||
logger.info(f"📥 Cloning repo to {target_dir}")
|
||||
try:
|
||||
subprocess.run(["git", "clone", "-b", branch, auth_url, str(target_dir)], check=True)
|
||||
# Clone without -b first, then fetch and checkout
|
||||
subprocess.run(["git", "clone", auth_url, str(target_dir)], check=True)
|
||||
|
||||
if pr_number:
|
||||
logger.info(f"📥 Fetching PR ref: refs/pull/{pr_number}/head")
|
||||
subprocess.run(["git", "fetch", "origin", f"refs/pull/{pr_number}/head"], cwd=target_dir, check=True)
|
||||
subprocess.run(["git", "checkout", "FETCH_HEAD"], cwd=target_dir, check=True)
|
||||
else:
|
||||
subprocess.run(["git", "checkout", branch], cwd=target_dir, check=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Git clone failed: {e}")
|
||||
raise e
|
||||
|
|
@ -412,6 +434,9 @@ class CIService:
|
|||
resp = await client.get(api_url, headers=headers)
|
||||
if resp.status_code == 200:
|
||||
return resp.text
|
||||
elif resp.status_code == 403:
|
||||
logger.error(f"❌ Failed to fetch diff: 403 Forbidden. This usually means the GITEA_BOT_TOKEN lacks 'read:repository' scope. Response: {resp.text[:200]}")
|
||||
return ""
|
||||
else:
|
||||
logger.error(f"Failed to fetch diff: {resp.status_code} - {resp.text[:200]}")
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class OpenAIEmbedding(EmbeddingProvider):
|
|||
"text-embedding-3-small": 1536,
|
||||
"text-embedding-3-large": 3072,
|
||||
"text-embedding-ada-002": 1536,
|
||||
"Qwen3-Embedding-4B": 2560,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
|
@ -73,7 +74,13 @@ class OpenAIEmbedding(EmbeddingProvider):
|
|||
or "https://api.openai.com/v1"
|
||||
)
|
||||
self.model = model
|
||||
self._dimension = self.MODELS.get(model, 1536)
|
||||
# 优先使用显式指定的维度,其次根据模型预定义,最后默认 1536
|
||||
self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
|
||||
self._dimension = (
|
||||
self._explicit_dimension
|
||||
if self._explicit_dimension > 0
|
||||
else self.MODELS.get(model, 1536)
|
||||
)
|
||||
|
||||
@property
|
||||
def dimension(self) -> int:
|
||||
|
|
@ -130,6 +137,7 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
|
|||
"text-embedding-3-small": 1536,
|
||||
"text-embedding-3-large": 3072,
|
||||
"text-embedding-ada-002": 1536,
|
||||
"Qwen3-Embedding-4B": 2560,
|
||||
}
|
||||
|
||||
# 最新的 GA API 版本
|
||||
|
|
@ -144,7 +152,13 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
|
|||
self.api_key = api_key
|
||||
self.base_url = base_url or "https://your-resource.openai.azure.com"
|
||||
self.model = model
|
||||
self._dimension = self.MODELS.get(model, 1536)
|
||||
# 优先使用项目配置中的显式维度
|
||||
self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
|
||||
self._dimension = (
|
||||
self._explicit_dimension
|
||||
if self._explicit_dimension > 0
|
||||
else self.MODELS.get(model, 1536)
|
||||
)
|
||||
|
||||
@property
|
||||
def dimension(self) -> int:
|
||||
|
|
@ -481,6 +495,7 @@ class QwenEmbedding(EmbeddingProvider):
|
|||
"text-embedding-v4": 1024, # 支持维度: 2048, 1536, 1024(默认), 768, 512, 256, 128, 64
|
||||
"text-embedding-v3": 1024, # 支持维度: 1024(默认), 768, 512, 256, 128, 64
|
||||
"text-embedding-v2": 1536, # 支持维度: 1536
|
||||
"Qwen3-Embedding-4B": 2560,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
|
@ -505,7 +520,14 @@ class QwenEmbedding(EmbeddingProvider):
|
|||
# DashScope 兼容 OpenAI 的 embeddings 端点
|
||||
self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
self.model = model
|
||||
self._dimension = self.MODELS.get(model, 1024)
|
||||
|
||||
# 优先使用显式指定的维度,其次根据模型预定义,最后默认 1024
|
||||
self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
|
||||
self._dimension = (
|
||||
self._explicit_dimension
|
||||
if self._explicit_dimension > 0
|
||||
else self.MODELS.get(model, 1024)
|
||||
)
|
||||
|
||||
@property
|
||||
def dimension(self) -> int:
|
||||
|
|
|
|||
|
|
@ -763,20 +763,22 @@ class CodeIndexer:
|
|||
Returns:
|
||||
(needs_rebuild, reason) - 是否需要重建及原因
|
||||
"""
|
||||
if self._initialized and not force_rebuild:
|
||||
return self._needs_rebuild, self._rebuild_reason
|
||||
# 如果 force_rebuild 为真,或者尚未初始化且 _needs_rebuild 已通过某种方式置为真
|
||||
should_recreate = force_rebuild or (not self._initialized and self._needs_rebuild)
|
||||
|
||||
# 先初始化 vector_store(不强制重建,只是获取现有 collection)
|
||||
await self.vector_store.initialize(force_recreate=False)
|
||||
# 先初始化 vector_store
|
||||
await self.vector_store.initialize(force_recreate=should_recreate)
|
||||
|
||||
if should_recreate:
|
||||
self._needs_rebuild = True
|
||||
self._rebuild_reason = "强制重建"
|
||||
self._initialized = True
|
||||
return True, self._rebuild_reason
|
||||
|
||||
# 检查是否需要重建
|
||||
self._needs_rebuild, self._rebuild_reason = await self._check_rebuild_needed()
|
||||
|
||||
if force_rebuild:
|
||||
self._needs_rebuild = True
|
||||
self._rebuild_reason = "用户强制重建"
|
||||
|
||||
# 如果需要重建,重新初始化 vector_store(强制重建)
|
||||
# 如果自动检测到需要重建,则再次初始化并强制重建
|
||||
if self._needs_rebuild:
|
||||
logger.info(f"🔄 需要重建索引: {self._rebuild_reason}")
|
||||
await self.vector_store.initialize(force_recreate=True)
|
||||
|
|
@ -819,11 +821,30 @@ class CodeIndexer:
|
|||
# 检查维度
|
||||
stored_dimension = stored_config.get("dimension")
|
||||
current_dimension = self.embedding_config.get("dimension")
|
||||
if stored_dimension and current_dimension and stored_dimension != current_dimension:
|
||||
|
||||
# 🔥 如果 metadata 中没有维度,尝试从 sample 中检测
|
||||
if not stored_dimension:
|
||||
stored_dimension = await self._detect_actual_dimension()
|
||||
if stored_dimension:
|
||||
logger.debug(f"🔍 从 sample 检测到实际维度: {stored_dimension}")
|
||||
|
||||
if stored_dimension and current_dimension and int(stored_dimension) != int(current_dimension):
|
||||
return True, f"Embedding 维度变更: {stored_dimension} -> {current_dimension}"
|
||||
|
||||
return False, ""
|
||||
|
||||
async def _detect_actual_dimension(self) -> Optional[int]:
|
||||
"""从既存向量中检测实际维度"""
|
||||
try:
|
||||
if hasattr(self.vector_store, '_collection') and self.vector_store._collection:
|
||||
peek = await asyncio.to_thread(self.vector_store._collection.peek, limit=1)
|
||||
embeddings = peek.get("embeddings")
|
||||
if embeddings and len(embeddings) > 0:
|
||||
return len(embeddings[0])
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
async def get_index_status(self) -> IndexStatus:
|
||||
"""获取索引状态"""
|
||||
await self.initialize()
|
||||
|
|
@ -873,7 +894,7 @@ class CodeIndexer:
|
|||
索引进度
|
||||
"""
|
||||
# 初始化并检查是否需要重建
|
||||
needs_rebuild, rebuild_reason = await self.initialize()
|
||||
needs_rebuild, rebuild_reason = await self.initialize(force_rebuild=(update_mode == IndexUpdateMode.FULL))
|
||||
|
||||
progress = IndexingProgress()
|
||||
exclude_patterns = exclude_patterns or []
|
||||
|
|
|
|||
|
|
@ -164,6 +164,12 @@ class CodeRetriever:
|
|||
if not api_key and self._provided_embedding_service:
|
||||
api_key = getattr(self._provided_embedding_service, 'api_key', None)
|
||||
|
||||
# 🔥 重要:如果用户显式指定了维度,且与存储的维度不匹配,则不应自动切换(会导致报错)
|
||||
explicit_dim = getattr(settings, "EMBEDDING_DIMENSION", 0)
|
||||
if explicit_dim > 0 and explicit_dim != stored_dimension:
|
||||
logger.warning(f"⚠️ Collection 维度 ({stored_dimension}) 与显式指定的维度 ({explicit_dim}) 不匹配,跳过自动切换以避免错误。")
|
||||
return
|
||||
|
||||
self.embedding_service = EmbeddingService(
|
||||
provider=stored_provider,
|
||||
model=stored_model,
|
||||
|
|
|
|||
|
|
@ -132,6 +132,9 @@ EMBEDDING_PROVIDER=openai
|
|||
# Ollama: nomic-embed-text, mxbai-embed-large
|
||||
EMBEDDING_MODEL=text-embedding-3-small
|
||||
|
||||
# 嵌入模型维度
|
||||
EMBEDDING_DIMENSION=2560
|
||||
|
||||
# 嵌入模型 API Key(留空则使用 LLM_API_KEY)
|
||||
EMBEDDING_API_KEY=
|
||||
|
||||
|
|
@ -219,5 +222,7 @@ ZIP_STORAGE_PATH=./uploads/zip_files
|
|||
# zh-CN: 中文
|
||||
# en-US: 英文
|
||||
OUTPUT_LANGUAGE=zh-CN
|
||||
|
||||
# Gitea 配置
|
||||
GITEA_HOST_URL=
|
||||
GITEA_BOT_TOKEN=
|
||||
Loading…
Reference in New Issue