"""
Code indexer.
Splits code into chunks and indexes them into a vector database.
"""

import os
import asyncio
import logging
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable
from pathlib import Path
from dataclasses import dataclass
import json

from .splitter import CodeSplitter, CodeChunk
from .embeddings import EmbeddingService

logger = logging.getLogger(__name__)


# Supported text file extensions
TEXT_EXTENSIONS = {
    ".py", ".js", ".ts", ".tsx", ".jsx", ".java", ".go", ".rs",
    ".cpp", ".c", ".h", ".cc", ".hh", ".cs", ".php", ".rb",
    ".kt", ".swift", ".sql", ".sh", ".json", ".yml", ".yaml",
    ".xml", ".html", ".css", ".vue", ".svelte", ".md",
}

# Directories to exclude
EXCLUDE_DIRS = {
    "node_modules", "vendor", "dist", "build", ".git",
    "__pycache__", ".pytest_cache", "coverage", ".nyc_output",
    ".vscode", ".idea", ".vs", "target", "out", "bin", "obj",
    "__MACOSX", ".next", ".nuxt", "venv", "env", ".env",
}

# Files to exclude
EXCLUDE_FILES = {
    ".DS_Store", "package-lock.json", "yarn.lock", "pnpm-lock.yaml",
    "Cargo.lock", "poetry.lock", "composer.lock", "Gemfile.lock",
}


@dataclass
class IndexingProgress:
    """Indexing progress."""
    total_files: int = 0
    processed_files: int = 0
    total_chunks: int = 0
    indexed_chunks: int = 0
    current_file: str = ""
    errors: Optional[List[str]] = None

    def __post_init__(self):
        if self.errors is None:
            self.errors = []

    @property
    def progress_percentage(self) -> float:
        if self.total_files == 0:
            return 0.0
        return (self.processed_files / self.total_files) * 100


@dataclass
class IndexingResult:
    """Indexing result."""
    success: bool
    total_files: int
    indexed_files: int
    total_chunks: int
    errors: List[str]
    collection_name: str


class VectorStore:
    """Abstract base class for vector stores."""

    async def initialize(self):
        """Initialize the store."""
        pass

    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents."""
        raise NotImplementedError

    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Query."""
        raise NotImplementedError

    async def delete_collection(self):
        """Delete the collection."""
        raise NotImplementedError

    async def get_count(self) -> int:
        """Get the number of documents."""
        raise NotImplementedError


class ChromaVectorStore(VectorStore):
    """Chroma vector store."""

    def __init__(
        self,
        collection_name: str,
        persist_directory: Optional[str] = None,
    ):
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        self._client = None
        self._collection = None

    async def initialize(self):
        """Initialize Chroma."""
        try:
            import chromadb
            from chromadb.config import Settings

            if self.persist_directory:
                self._client = chromadb.PersistentClient(
                    path=self.persist_directory,
                    settings=Settings(anonymized_telemetry=False),
                )
            else:
                self._client = chromadb.Client(
                    settings=Settings(anonymized_telemetry=False),
                )

            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": "cosine"},
            )

            logger.info(f"Chroma collection '{self.collection_name}' initialized")

        except ImportError:
            raise ImportError("chromadb is required. Install with: pip install chromadb")

    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents to Chroma."""
        if not ids:
            return

        # Chroma only accepts scalar metadata values, so clean them up
        cleaned_metadatas = []
        for meta in metadatas:
            cleaned = {}
            for k, v in meta.items():
                if isinstance(v, (str, int, float, bool)):
                    cleaned[k] = v
                elif isinstance(v, list):
                    # Serialize lists as JSON strings
                    cleaned[k] = json.dumps(v)
                elif v is not None:
                    cleaned[k] = str(v)
            cleaned_metadatas.append(cleaned)

        # Add in batches (Chroma limits batch sizes)
        batch_size = 500
        for i in range(0, len(ids), batch_size):
            batch_ids = ids[i:i + batch_size]
            batch_embeddings = embeddings[i:i + batch_size]
            batch_documents = documents[i:i + batch_size]
            batch_metadatas = cleaned_metadatas[i:i + batch_size]

            await asyncio.to_thread(
                self._collection.add,
                ids=batch_ids,
                embeddings=batch_embeddings,
                documents=batch_documents,
                metadatas=batch_metadatas,
            )

    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Query Chroma."""
        result = await asyncio.to_thread(
            self._collection.query,
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        return {
            "ids": result["ids"][0] if result["ids"] else [],
            "documents": result["documents"][0] if result["documents"] else [],
            "metadatas": result["metadatas"][0] if result["metadatas"] else [],
            "distances": result["distances"][0] if result["distances"] else [],
        }

    async def delete_collection(self):
        """Delete the collection."""
        if self._client and self._collection:
            await asyncio.to_thread(
                self._client.delete_collection,
                name=self.collection_name,
            )

    async def get_count(self) -> int:
        """Get the number of documents."""
        if self._collection:
            return await asyncio.to_thread(self._collection.count)
        return 0


class InMemoryVectorStore(VectorStore):
    """In-memory vector store (for tests or small projects)."""

    def __init__(self, collection_name: str):
        self.collection_name = collection_name
        self._documents: Dict[str, Dict[str, Any]] = {}

    async def initialize(self):
        """Initialize the store."""
        logger.info(f"InMemory vector store '{self.collection_name}' initialized")

    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents."""
        for id_, emb, doc, meta in zip(ids, embeddings, documents, metadatas):
            self._documents[id_] = {
                "embedding": emb,
                "document": doc,
                "metadata": meta,
            }

    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Query using cosine similarity."""
        import math

        def cosine_similarity(a: List[float], b: List[float]) -> float:
            dot = sum(x * y for x, y in zip(a, b))
            norm_a = math.sqrt(sum(x * x for x in a))
            norm_b = math.sqrt(sum(x * x for x in b))
            if norm_a == 0 or norm_b == 0:
                return 0.0
            return dot / (norm_a * norm_b)

        results = []
        for id_, data in self._documents.items():
            # Apply metadata filters
            if where:
                match = True
                for k, v in where.items():
                    if data["metadata"].get(k) != v:
                        match = False
                        break
                if not match:
                    continue

            similarity = cosine_similarity(query_embedding, data["embedding"])
            results.append({
                "id": id_,
                "document": data["document"],
                "metadata": data["metadata"],
                "distance": 1 - similarity,  # convert similarity to distance
            })

        # Sort by ascending distance
        results.sort(key=lambda x: x["distance"])
        results = results[:n_results]

        return {
            "ids": [r["id"] for r in results],
            "documents": [r["document"] for r in results],
            "metadatas": [r["metadata"] for r in results],
            "distances": [r["distance"] for r in results],
        }

    async def delete_collection(self):
        """Delete the collection."""
        self._documents.clear()

    async def get_count(self) -> int:
        """Get the number of documents."""
        return len(self._documents)
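

# Illustrative sketch of the VectorStore contract, using the in-memory store with
# tiny hand-written embeddings (real callers get embeddings from EmbeddingService).
# The ids, documents, and metadata below are made-up example values, not part of
# the indexer's data model.
async def _demo_in_memory_store() -> None:
    store = InMemoryVectorStore(collection_name="demo")
    await store.initialize()
    await store.add_documents(
        ids=["a", "b"],
        embeddings=[[1.0, 0.0], [0.0, 1.0]],
        documents=["def foo(): ...", "def bar(): ..."],
        metadatas=[{"file_path": "foo.py"}, {"file_path": "bar.py"}],
    )
    # The closest document (smallest cosine distance) comes back first.
    result = await store.query(query_embedding=[1.0, 0.0], n_results=1)
    assert result["documents"] == ["def foo(): ..."]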


class CodeIndexer:
    """
    Code indexer.
    Splits code files into chunks, embeds them, and indexes them into a vector database.
    """

    def __init__(
        self,
        collection_name: str,
        embedding_service: Optional[EmbeddingService] = None,
        vector_store: Optional[VectorStore] = None,
        splitter: Optional[CodeSplitter] = None,
        persist_directory: Optional[str] = None,
    ):
        """
        Initialize the indexer.

        Args:
            collection_name: Name of the vector collection
            embedding_service: Embedding service
            vector_store: Vector store
            splitter: Code splitter
            persist_directory: Persistence directory
        """
        self.collection_name = collection_name
        self.embedding_service = embedding_service or EmbeddingService()
        self.splitter = splitter or CodeSplitter()

        # Create the vector store, falling back to the in-memory store when
        # chromadb is not installed
        if vector_store:
            self.vector_store = vector_store
        else:
            try:
                import chromadb  # noqa: F401  # probe availability before choosing Chroma
                self.vector_store = ChromaVectorStore(
                    collection_name=collection_name,
                    persist_directory=persist_directory,
                )
            except ImportError:
                logger.warning("Chroma not available, using in-memory store")
                self.vector_store = InMemoryVectorStore(collection_name=collection_name)

        self._initialized = False

    async def initialize(self):
        """Initialize the indexer."""
        if not self._initialized:
            await self.vector_store.initialize()
            self._initialized = True

    async def index_directory(
        self,
        directory: str,
        exclude_patterns: Optional[List[str]] = None,
        include_patterns: Optional[List[str]] = None,
        progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
    ) -> AsyncGenerator[IndexingProgress, None]:
        """
        Index the code files in a directory.

        Args:
            directory: Directory path
            exclude_patterns: Exclude patterns
            include_patterns: Include patterns
            progress_callback: Progress callback

        Yields:
            Indexing progress
        """
        await self.initialize()

        progress = IndexingProgress()
        exclude_patterns = exclude_patterns or []

        # Collect the files to index
        files = self._collect_files(directory, exclude_patterns, include_patterns)
        progress.total_files = len(files)

        logger.info(f"Found {len(files)} files to index in {directory}")
        yield progress

        all_chunks: List[CodeChunk] = []

        # Split each file into chunks
        for file_path in files:
            progress.current_file = file_path

            try:
                relative_path = os.path.relpath(file_path, directory)

                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()

                if not content.strip():
                    progress.processed_files += 1
                    continue

                # Cap the file size
                if len(content) > 500000:  # 500KB
                    content = content[:500000]

                # Split into chunks
                chunks = self.splitter.split_file(content, relative_path)
                all_chunks.extend(chunks)

                progress.processed_files += 1
                progress.total_chunks = len(all_chunks)

                if progress_callback:
                    progress_callback(progress)
                yield progress

            except Exception as e:
                logger.warning(f"Error processing {file_path}: {e}")
                progress.errors.append(f"{file_path}: {str(e)}")
                progress.processed_files += 1

        logger.info(f"Created {len(all_chunks)} chunks from {len(files)} files")

        # Embed and index in batches
        if all_chunks:
            await self._index_chunks(all_chunks, progress)

        progress.indexed_chunks = len(all_chunks)
        yield progress

    async def index_files(
        self,
        files: List[Dict[str, str]],
        base_path: str = "",
        progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
    ) -> AsyncGenerator[IndexingProgress, None]:
        """
        Index a list of files.

        Args:
            files: List of files, e.g. [{"path": "...", "content": "..."}]
            base_path: Base path
            progress_callback: Progress callback

        Yields:
            Indexing progress
        """
        await self.initialize()

        progress = IndexingProgress()
        progress.total_files = len(files)

        all_chunks: List[CodeChunk] = []

        for file_info in files:
            file_path = file_info.get("path", "")
            content = file_info.get("content", "")

            progress.current_file = file_path

            try:
                if not content.strip():
                    progress.processed_files += 1
                    continue

                # Cap the file size
                if len(content) > 500000:
                    content = content[:500000]

                # Split into chunks
                chunks = self.splitter.split_file(content, file_path)
                all_chunks.extend(chunks)

                progress.processed_files += 1
                progress.total_chunks = len(all_chunks)

                if progress_callback:
                    progress_callback(progress)
                yield progress

            except Exception as e:
                logger.warning(f"Error processing {file_path}: {e}")
                progress.errors.append(f"{file_path}: {str(e)}")
                progress.processed_files += 1

        # Embed and index in batches
        if all_chunks:
            await self._index_chunks(all_chunks, progress)

        progress.indexed_chunks = len(all_chunks)
        yield progress

    async def _index_chunks(self, chunks: List[CodeChunk], progress: IndexingProgress):
        """Index code chunks."""
        # Prepare the texts to embed
        texts = [chunk.to_embedding_text() for chunk in chunks]

        logger.info(f"Generating embeddings for {len(texts)} chunks...")

        # Embed in batches
        embeddings = await self.embedding_service.embed_batch(texts, batch_size=50)

        # Prepare ids, documents, and metadata
        ids = [chunk.id for chunk in chunks]
        documents = [chunk.content for chunk in chunks]
        metadatas = [chunk.to_dict() for chunk in chunks]

        # Add everything to the vector store
        logger.info(f"Adding {len(chunks)} chunks to vector store...")
        await self.vector_store.add_documents(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )

        logger.info(f"Indexed {len(chunks)} chunks successfully")

    def _collect_files(
        self,
        directory: str,
        exclude_patterns: List[str],
        include_patterns: Optional[List[str]],
    ) -> List[str]:
        """Collect the files to be indexed."""
        import fnmatch

        files = []

        for root, dirs, filenames in os.walk(directory):
            # Prune excluded directories in place
            dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]

            for filename in filenames:
                # Check the extension
                ext = os.path.splitext(filename)[1].lower()
                if ext not in TEXT_EXTENSIONS:
                    continue

                # Skip explicitly excluded files
                if filename in EXCLUDE_FILES:
                    continue

                file_path = os.path.join(root, filename)
                relative_path = os.path.relpath(file_path, directory)

                # Check exclude patterns
                excluded = False
                for pattern in exclude_patterns:
                    if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
                        excluded = True
                        break

                if excluded:
                    continue

                # Check include patterns
                if include_patterns:
                    included = False
                    for pattern in include_patterns:
                        if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
                            included = True
                            break
                    if not included:
                        continue

                files.append(file_path)

        return files

    async def get_chunk_count(self) -> int:
        """Get the number of indexed code chunks."""
        await self.initialize()
        return await self.vector_store.get_count()

    async def clear(self):
        """Clear the index."""
        await self.initialize()
        await self.vector_store.delete_collection()
        self._initialized = False
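

# Example usage (illustrative sketch): index two small in-memory files with the
# in-memory vector store so chromadb is not required. Assumes the default
# EmbeddingService is configured with a working embedding backend, and that the
# module is run as part of its package (it uses relative imports), e.g. via
# `python -m <package>.indexer`. The file paths and contents are made-up samples.
if __name__ == "__main__":
    async def _main() -> None:
        indexer = CodeIndexer(
            collection_name="example",
            vector_store=InMemoryVectorStore(collection_name="example"),
        )
        sample_files = [
            {"path": "app/main.py", "content": "def main():\n    print('hello')\n"},
            {"path": "app/utils.py", "content": "def add(a, b):\n    return a + b\n"},
        ]
        async for progress in indexer.index_files(sample_files):
            print(
                f"{progress.processed_files}/{progress.total_files} files, "
                f"{progress.total_chunks} chunks"
            )
        print("indexed chunks:", await indexer.get_chunk_count())

    asyncio.run(_main())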