CodeReview/backend/app/services/rag/indexer.py

"""
代码索引器
将代码分块并索引到向量数据库
"""
import os
import asyncio
import logging
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable
from pathlib import Path
from dataclasses import dataclass
import json
from .splitter import CodeSplitter, CodeChunk
from .embeddings import EmbeddingService
logger = logging.getLogger(__name__)
# Supported text file extensions
TEXT_EXTENSIONS = {
".py", ".js", ".ts", ".tsx", ".jsx", ".java", ".go", ".rs",
".cpp", ".c", ".h", ".cc", ".hh", ".cs", ".php", ".rb",
".kt", ".swift", ".sql", ".sh", ".json", ".yml", ".yaml",
".xml", ".html", ".css", ".vue", ".svelte", ".md",
}
# Directories excluded from indexing
EXCLUDE_DIRS = {
"node_modules", "vendor", "dist", "build", ".git",
"__pycache__", ".pytest_cache", "coverage", ".nyc_output",
".vscode", ".idea", ".vs", "target", "out", "bin", "obj",
"__MACOSX", ".next", ".nuxt", "venv", "env", ".env",
}
# Files excluded from indexing (lock files and similar generated artifacts)
EXCLUDE_FILES = {
".DS_Store", "package-lock.json", "yarn.lock", "pnpm-lock.yaml",
"Cargo.lock", "poetry.lock", "composer.lock", "Gemfile.lock",
}
@dataclass
class IndexingProgress:
"""索引进度"""
total_files: int = 0
processed_files: int = 0
total_chunks: int = 0
indexed_chunks: int = 0
current_file: str = ""
errors: Optional[List[str]] = None
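# A list literal can't be used as a dataclass default, so errors starts as None
# and __post_init__ replaces it with an empty list.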
def __post_init__(self):
if self.errors is None:
self.errors = []
@property
def progress_percentage(self) -> float:
if self.total_files == 0:
return 0.0
return (self.processed_files / self.total_files) * 100
@dataclass
class IndexingResult:
"""索引结果"""
success: bool
total_files: int
indexed_files: int
total_chunks: int
errors: List[str]
collection_name: str
class VectorStore:
"""向量存储抽象基类"""
async def initialize(self):
"""初始化存储"""
pass
async def add_documents(
self,
ids: List[str],
embeddings: List[List[float]],
documents: List[str],
metadatas: List[Dict[str, Any]],
):
"""添加文档"""
raise NotImplementedError
async def query(
self,
query_embedding: List[float],
n_results: int = 10,
where: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""查询"""
raise NotImplementedError
async def delete_collection(self):
"""删除集合"""
raise NotImplementedError
async def get_count(self) -> int:
"""获取文档数量"""
raise NotImplementedError
class ChromaVectorStore(VectorStore):
"""Chroma 向量存储"""
def __init__(
self,
collection_name: str,
persist_directory: Optional[str] = None,
):
self.collection_name = collection_name
self.persist_directory = persist_directory
self._client = None
self._collection = None
async def initialize(self):
"""初始化 Chroma"""
try:
import chromadb
from chromadb.config import Settings
if self.persist_directory:
self._client = chromadb.PersistentClient(
path=self.persist_directory,
settings=Settings(anonymized_telemetry=False),
)
else:
self._client = chromadb.Client(
settings=Settings(anonymized_telemetry=False),
)
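# Cosine space: returned distances correspond to 1 - cosine similarity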
self._collection = self._client.get_or_create_collection(
name=self.collection_name,
metadata={"hnsw:space": "cosine"},
)
logger.info(f"Chroma collection '{self.collection_name}' initialized")
except ImportError:
raise ImportError("chromadb is required. Install with: pip install chromadb")
async def add_documents(
self,
ids: List[str],
embeddings: List[List[float]],
documents: List[str],
metadatas: List[Dict[str, Any]],
):
"""添加文档到 Chroma"""
if not ids:
return
# Chroma only accepts scalar metadata values, so clean them first
cleaned_metadatas = []
for meta in metadatas:
cleaned = {}
for k, v in meta.items():
if isinstance(v, (str, int, float, bool)):
cleaned[k] = v
elif isinstance(v, list):
# Serialize lists as JSON strings
cleaned[k] = json.dumps(v)
elif v is not None:
cleaned[k] = str(v)
cleaned_metadatas.append(cleaned)
# Add in batches (Chroma limits the batch size)
batch_size = 500
for i in range(0, len(ids), batch_size):
batch_ids = ids[i:i + batch_size]
batch_embeddings = embeddings[i:i + batch_size]
batch_documents = documents[i:i + batch_size]
batch_metadatas = cleaned_metadatas[i:i + batch_size]
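# The Chroma client is synchronous, so run the blocking add in a worker thread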
await asyncio.to_thread(
self._collection.add,
ids=batch_ids,
embeddings=batch_embeddings,
documents=batch_documents,
metadatas=batch_metadatas,
)
async def query(
self,
query_embedding: List[float],
n_results: int = 10,
where: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""查询 Chroma"""
result = await asyncio.to_thread(
self._collection.query,
query_embeddings=[query_embedding],
n_results=n_results,
where=where,
include=["documents", "metadatas", "distances"],
)
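# Chroma returns one result list per query embedding; flatten for the single-query case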
return {
"ids": result["ids"][0] if result["ids"] else [],
"documents": result["documents"][0] if result["documents"] else [],
"metadatas": result["metadatas"][0] if result["metadatas"] else [],
"distances": result["distances"][0] if result["distances"] else [],
}
async def delete_collection(self):
"""删除集合"""
if self._client and self._collection:
await asyncio.to_thread(
self._client.delete_collection,
name=self.collection_name,
)
async def get_count(self) -> int:
"""获取文档数量"""
if self._collection:
return await asyncio.to_thread(self._collection.count)
return 0
class InMemoryVectorStore(VectorStore):
"""内存向量存储(用于测试或小项目)"""
def __init__(self, collection_name: str):
self.collection_name = collection_name
self._documents: Dict[str, Dict[str, Any]] = {}
async def initialize(self):
"""初始化"""
logger.info(f"InMemory vector store '{self.collection_name}' initialized")
async def add_documents(
self,
ids: List[str],
embeddings: List[List[float]],
documents: List[str],
metadatas: List[Dict[str, Any]],
):
"""添加文档"""
for id_, emb, doc, meta in zip(ids, embeddings, documents, metadatas):
self._documents[id_] = {
"embedding": emb,
"document": doc,
"metadata": meta,
}
async def query(
self,
query_embedding: List[float],
n_results: int = 10,
where: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""查询(使用余弦相似度)"""
import math
def cosine_similarity(a: List[float], b: List[float]) -> float:
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
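# Linear scan over every stored document; acceptable for the small collections this store targets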
results = []
for id_, data in self._documents.items():
# Apply metadata filters (exact match only)
if where:
match = True
for k, v in where.items():
if data["metadata"].get(k) != v:
match = False
break
if not match:
continue
similarity = cosine_similarity(query_embedding, data["embedding"])
results.append({
"id": id_,
"document": data["document"],
"metadata": data["metadata"],
"distance": 1 - similarity, # 转换为距离
})
# Sort by distance (ascending)
results.sort(key=lambda x: x["distance"])
results = results[:n_results]
return {
"ids": [r["id"] for r in results],
"documents": [r["document"] for r in results],
"metadatas": [r["metadata"] for r in results],
"distances": [r["distance"] for r in results],
}
async def delete_collection(self):
"""删除集合"""
self._documents.clear()
async def get_count(self) -> int:
"""获取文档数量"""
return len(self._documents)
class CodeIndexer:
"""
代码索引器
将代码文件分块、嵌入并索引到向量数据库
"""
def __init__(
self,
collection_name: str,
embedding_service: Optional[EmbeddingService] = None,
vector_store: Optional[VectorStore] = None,
splitter: Optional[CodeSplitter] = None,
persist_directory: Optional[str] = None,
):
"""
初始化索引器
Args:
collection_name: 向量集合名称
embedding_service: 嵌入服务
vector_store: 向量存储
splitter: 代码分块器
persist_directory: 持久化目录
"""
self.collection_name = collection_name
self.embedding_service = embedding_service or EmbeddingService()
self.splitter = splitter or CodeSplitter()
# Create the vector store, falling back to an in-memory store if Chroma is unavailable
if vector_store:
self.vector_store = vector_store
else:
try:
self.vector_store = ChromaVectorStore(
collection_name=collection_name,
persist_directory=persist_directory,
)
except ImportError:
logger.warning("Chroma not available, using in-memory store")
self.vector_store = InMemoryVectorStore(collection_name=collection_name)
self._initialized = False
async def initialize(self):
"""初始化索引器"""
if not self._initialized:
await self.vector_store.initialize()
self._initialized = True
async def index_directory(
self,
directory: str,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
) -> AsyncGenerator[IndexingProgress, None]:
"""
索引目录中的代码文件
Args:
directory: 目录路径
exclude_patterns: 排除模式
include_patterns: 包含模式
progress_callback: 进度回调
Yields:
索引进度
"""
await self.initialize()
progress = IndexingProgress()
exclude_patterns = exclude_patterns or []
# Collect candidate files
files = self._collect_files(directory, exclude_patterns, include_patterns)
progress.total_files = len(files)
logger.info(f"Found {len(files)} files to index in {directory}")
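# Emit an initial progress update so callers can show the total file count before processing starts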
yield progress
all_chunks: List[CodeChunk] = []
# Split files into chunks
for file_path in files:
progress.current_file = file_path
try:
relative_path = os.path.relpath(file_path, directory)
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
if not content.strip():
progress.processed_files += 1
continue
# Cap very large files so chunking and embedding stay bounded
if len(content) > 500000:  # ~500k characters
content = content[:500000]
# Split into chunks
chunks = self.splitter.split_file(content, relative_path)
all_chunks.extend(chunks)
progress.processed_files += 1
progress.total_chunks = len(all_chunks)
if progress_callback:
progress_callback(progress)
yield progress
except Exception as e:
logger.warning(f"Error processing {file_path}: {e}")
progress.errors.append(f"{file_path}: {str(e)}")
progress.processed_files += 1
logger.info(f"Created {len(all_chunks)} chunks from {len(files)} files")
# Embed and index all chunks in batches
if all_chunks:
await self._index_chunks(all_chunks, progress)
progress.indexed_chunks = len(all_chunks)
yield progress
async def index_files(
self,
files: List[Dict[str, str]],
base_path: str = "",
progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
) -> AsyncGenerator[IndexingProgress, None]:
"""
索引文件列表
Args:
files: 文件列表 [{"path": "...", "content": "..."}]
base_path: 基础路径
progress_callback: 进度回调
Yields:
索引进度
"""
await self.initialize()
progress = IndexingProgress()
progress.total_files = len(files)
all_chunks: List[CodeChunk] = []
for file_info in files:
file_path = file_info.get("path", "")
content = file_info.get("content", "")
progress.current_file = file_path
try:
if not content.strip():
progress.processed_files += 1
continue
# Cap very large files
if len(content) > 500000:
content = content[:500000]
# Split into chunks
chunks = self.splitter.split_file(content, file_path)
all_chunks.extend(chunks)
progress.processed_files += 1
progress.total_chunks = len(all_chunks)
if progress_callback:
progress_callback(progress)
yield progress
except Exception as e:
logger.warning(f"Error processing {file_path}: {e}")
progress.errors.append(f"{file_path}: {str(e)}")
progress.processed_files += 1
# Embed and index all chunks in batches
if all_chunks:
await self._index_chunks(all_chunks, progress)
progress.indexed_chunks = len(all_chunks)
yield progress
async def _index_chunks(self, chunks: List[CodeChunk], progress: IndexingProgress):
"""索引代码块"""
# 准备嵌入文本
texts = [chunk.to_embedding_text() for chunk in chunks]
logger.info(f"Generating embeddings for {len(texts)} chunks...")
# Embed in batches
embeddings = await self.embedding_service.embed_batch(texts, batch_size=50)
# Prepare ids and metadata; the raw chunk content is stored as the document text
ids = [chunk.id for chunk in chunks]
documents = [chunk.content for chunk in chunks]
metadatas = [chunk.to_dict() for chunk in chunks]
# Add to the vector store
logger.info(f"Adding {len(chunks)} chunks to vector store...")
await self.vector_store.add_documents(
ids=ids,
embeddings=embeddings,
documents=documents,
metadatas=metadatas,
)
logger.info(f"Indexed {len(chunks)} chunks successfully")
def _collect_files(
self,
directory: str,
exclude_patterns: List[str],
include_patterns: Optional[List[str]],
) -> List[str]:
"""收集需要索引的文件"""
import fnmatch
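# fnmatch provides shell-style glob matching for the include/exclude patterns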
files = []
for root, dirs, filenames in os.walk(directory):
# Prune excluded directories in place so os.walk skips them
dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
for filename in filenames:
# Check the file extension
ext = os.path.splitext(filename)[1].lower()
if ext not in TEXT_EXTENSIONS:
continue
# Skip explicitly excluded files
if filename in EXCLUDE_FILES:
continue
file_path = os.path.join(root, filename)
relative_path = os.path.relpath(file_path, directory)
# Check exclude patterns
excluded = False
for pattern in exclude_patterns:
if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
excluded = True
break
if excluded:
continue
# Check include patterns
if include_patterns:
included = False
for pattern in include_patterns:
if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
included = True
break
if not included:
continue
files.append(file_path)
return files
async def get_chunk_count(self) -> int:
"""获取已索引的代码块数量"""
await self.initialize()
return await self.vector_store.get_count()
async def clear(self):
"""清空索引"""
await self.initialize()
await self.vector_store.delete_collection()
self._initialized = False
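# Illustrative usage sketch (not part of the module itself). Assumes the default
# EmbeddingService and CodeSplitter; the collection name, persist directory, and
# repository path below are placeholders.
#
#   async def index_repo() -> None:
#       indexer = CodeIndexer(collection_name="my-repo", persist_directory="./chroma_data")
#       async for progress in indexer.index_directory("/path/to/repo"):
#           print(f"{progress.progress_percentage:.1f}%  {progress.current_file}")
#       print(f"Indexed chunks: {await indexer.get_chunk_count()}")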