diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 9264657..346f2d3 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -106,7 +106,15 @@ class Settings(BaseSettings):
     AGENT_TOKEN_BUDGET: int = 100000  # Agent token budget
     AGENT_TIMEOUT_SECONDS: int = 1800  # Agent timeout (30 minutes)
 
+    # Embedding concurrency configuration
+    EMBEDDING_CONCURRENCY: int = 2  # Calculated: with batch=10, concurrency 2 comes closest to the 1.2M TPM rate limit
+    EMBEDDING_RETRY_MAX: int = 10  # Maximum retries for the embedding model
+
+
+
+    # Sandbox configuration (required)
+
     SANDBOX_IMAGE: str = "deepaudit/sandbox:latest"  # Sandbox Docker image
     SANDBOX_MEMORY_LIMIT: str = "512m"  # Sandbox memory limit
     SANDBOX_CPU_LIMIT: float = 1.0  # Sandbox CPU limit
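A quick sanity check on the `EMBEDDING_CONCURRENCY` comment above: the 1.2M TPM figure only pins down a concurrency once you also fix a chunk size and a request latency. A minimal back-of-envelope sketch, assuming ~2,000 tokens per chunk and ~2 s per request (both illustrative assumptions, not values from this codebase):

```python
# Rough throughput estimate for the chosen settings.
BATCH_SIZE = 10            # texts per request (DashScope text-embedding-v4 cap)
TOKENS_PER_CHUNK = 2_000   # assumed average tokens per chunk
REQUEST_LATENCY_S = 2.0    # assumed round-trip time per embedding request

tokens_per_request = BATCH_SIZE * TOKENS_PER_CHUNK  # ~20k tokens
for concurrency in (1, 2, 3):
    requests_per_minute = concurrency * 60 / REQUEST_LATENCY_S
    tpm = requests_per_minute * tokens_per_request
    print(f"concurrency={concurrency}: ~{tpm / 1e6:.1f}M TPM")
# -> ~0.6M, ~1.2M, ~1.8M: under these assumptions, concurrency 2 sits
#    right at the 1.2M TPM line and concurrency 3 would overshoot it.
```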
diff --git a/backend/app/services/rag/embeddings.py b/backend/app/services/rag/embeddings.py
index 894024b..bcef650 100644
--- a/backend/app/services/rag/embeddings.py
+++ b/backend/app/services/rag/embeddings.py
@@ -34,7 +34,7 @@ class EmbeddingProvider(ABC):
         pass
 
     @abstractmethod
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         """Embed a batch of texts"""
         pass
 
@@ -90,7 +90,7 @@ class OpenAIEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
 
@@ -109,20 +109,25 @@ class OpenAIEmbedding(EmbeddingProvider):
 
         url = f"{self.base_url.rstrip('/')}/embeddings"
 
-        async with httpx.AsyncClient(timeout=60) as client:
+        if client:
             response = await client.post(url, headers=headers, json=payload)
             response.raise_for_status()
             data = response.json()
+        else:
+            async with httpx.AsyncClient(timeout=60) as client:
+                response = await client.post(url, headers=headers, json=payload)
+                response.raise_for_status()
+                data = response.json()
 
-            results = []
-            for item in data.get("data", []):
-                results.append(EmbeddingResult(
-                    embedding=item["embedding"],
-                    tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
-                    model=self.model,
-                ))
-
-            return results
+        results = []
+        for item in data.get("data", []):
+            results.append(EmbeddingResult(
+                embedding=item["embedding"],
+                tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
+                model=self.model,
+            ))
+
+        return results
 
 
 class AzureOpenAIEmbedding(EmbeddingProvider):
@@ -168,7 +173,7 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
 
@@ -187,20 +192,25 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
 
         # Azure URL format - use the latest API version
         url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"
 
-        async with httpx.AsyncClient(timeout=60) as client:
+        if client:
             response = await client.post(url, headers=headers, json=payload)
             response.raise_for_status()
             data = response.json()
+        else:
+            async with httpx.AsyncClient(timeout=60) as client:
+                response = await client.post(url, headers=headers, json=payload)
+                response.raise_for_status()
+                data = response.json()
 
-            results = []
-            for item in data.get("data", []):
-                results.append(EmbeddingResult(
-                    embedding=item["embedding"],
-                    tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
-                    model=self.model,
-                ))
-
-            return results
+        results = []
+        for item in data.get("data", []):
+            results.append(EmbeddingResult(
+                embedding=item["embedding"],
+                tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
+                model=self.model,
+            ))
+
+        return results
 
 
 class OllamaEmbedding(EmbeddingProvider):
@@ -238,7 +248,7 @@ class OllamaEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
 
@@ -250,23 +260,28 @@ class OllamaEmbedding(EmbeddingProvider):
             "input": texts,  # the new API uses the 'input' parameter and supports batching
         }
 
-        async with httpx.AsyncClient(timeout=120) as client:
+        if client:
             response = await client.post(url, json=payload)
             response.raise_for_status()
             data = response.json()
+        else:
+            async with httpx.AsyncClient(timeout=120) as client:
+                response = await client.post(url, json=payload)
+                response.raise_for_status()
+                data = response.json()
 
-            # new API response format: {"embeddings": [[...], [...], ...]}
-            embeddings = data.get("embeddings", [])
-
-            results = []
-            for i, embedding in enumerate(embeddings):
-                results.append(EmbeddingResult(
-                    embedding=embedding,
-                    tokens_used=len(texts[i]) // 4,
-                    model=self.model,
-                ))
-
-            return results
+        # new API response format: {"embeddings": [[...], [...], ...]}
+        embeddings = data.get("embeddings", [])
+
+        results = []
+        for i, embedding in enumerate(embeddings):
+            results.append(EmbeddingResult(
+                embedding=embedding,
+                tokens_used=len(texts[i]) // 4,
+                model=self.model,
+            ))
+
+        return results
 
 
 class CohereEmbedding(EmbeddingProvider):
@@ -307,7 +322,7 @@ class CohereEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
 
@@ -326,24 +341,29 @@ class CohereEmbedding(EmbeddingProvider):
 
         url = f"{self.base_url.rstrip('/')}/embed"
 
-        async with httpx.AsyncClient(timeout=60) as client:
+        if client:
             response = await client.post(url, headers=headers, json=payload)
             response.raise_for_status()
             data = response.json()
+        else:
+            async with httpx.AsyncClient(timeout=60) as client:
+                response = await client.post(url, headers=headers, json=payload)
+                response.raise_for_status()
+                data = response.json()
 
-            results = []
-            # v2 API response format: {"embeddings": {"float": [[...], [...]]}, ...}
-            embeddings_data = data.get("embeddings", {})
-            embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
-
-            for embedding in embeddings:
-                results.append(EmbeddingResult(
-                    embedding=embedding,
-                    tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
-                    model=self.model,
-                ))
-
-            return results
+        results = []
+        # v2 API response format: {"embeddings": {"float": [[...], [...]]}, ...}
+        embeddings_data = data.get("embeddings", {})
+        embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
+
+        for embedding in embeddings:
+            results.append(EmbeddingResult(
+                embedding=embedding,
+                tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
+                model=self.model,
+            ))
+
+        return results
 
 
 class HuggingFaceEmbedding(EmbeddingProvider):
@@ -383,7 +403,7 @@ class HuggingFaceEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
 
@@ -393,7 +413,6 @@ class HuggingFaceEmbedding(EmbeddingProvider):
         }
 
         # new HuggingFace Router URL format
-        # https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
         url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
 
         payload = {
@@ -403,27 +422,29 @@ class HuggingFaceEmbedding(EmbeddingProvider):
             }
         }
 
-        async with httpx.AsyncClient(timeout=120) as client:
+        if client:
             response = await client.post(url, headers=headers, json=payload)
             response.raise_for_status()
             data = response.json()
+        else:
+            async with httpx.AsyncClient(timeout=120) as client:
+                response = await client.post(url, headers=headers, json=payload)
+                response.raise_for_status()
+                data = response.json()
 
-            results = []
-            # HuggingFace response format: [[embedding1], [embedding2], ...]
-            for embedding in data:
-                # sometimes the response is a nested list
-                if isinstance(embedding, list) and len(embedding) > 0:
-                    if isinstance(embedding[0], list):
-                        # take the average or the first element
-                        embedding = embedding[0]
-
-                results.append(EmbeddingResult(
-                    embedding=embedding,
-                    tokens_used=len(texts[len(results)]) // 4,
-                    model=self.model,
-                ))
+        results = []
+        for embedding in data:
+            if isinstance(embedding, list) and len(embedding) > 0:
+                if isinstance(embedding[0], list):
+                    embedding = embedding[0]
 
-            return results
+            results.append(EmbeddingResult(
+                embedding=embedding,
+                tokens_used=len(texts[len(results)]) // 4,
+                model=self.model,
+            ))
+
+        return results
 
 
 class JinaEmbedding(EmbeddingProvider):
@@ -455,7 +476,7 @@ class JinaEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
 
@@ -471,20 +492,25 @@ class JinaEmbedding(EmbeddingProvider):
 
         url = f"{self.base_url.rstrip('/')}/embeddings"
 
-        async with httpx.AsyncClient(timeout=60) as client:
+        if client:
             response = await client.post(url, headers=headers, json=payload)
             response.raise_for_status()
             data = response.json()
+        else:
+            async with httpx.AsyncClient(timeout=60) as client:
+                response = await client.post(url, headers=headers, json=payload)
+                response.raise_for_status()
+                data = response.json()
 
-            results = []
-            for item in data.get("data", []):
-                results.append(EmbeddingResult(
-                    embedding=item["embedding"],
-                    tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
-                    model=self.model,
-                ))
-
-            return results
+        results = []
+        for item in data.get("data", []):
+            results.append(EmbeddingResult(
+                embedding=item["embedding"],
+                tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
+                model=self.model,
+            ))
+
+        return results
 
 
 class QwenEmbedding(EmbeddingProvider):
@@ -537,7 +563,7 @@ class QwenEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
 
@@ -557,30 +583,30 @@ class QwenEmbedding(EmbeddingProvider):
 
         url = f"{self.base_url.rstrip('/')}/embeddings"
f"{self.base_url.rstrip('/')}/embeddings" try: - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post( - url, - headers=headers, - json=payload, - ) - response.raise_for_status() - data = response.json() + if client: + response = await client.post(url, headers=headers, json=payload) + else: + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post(url, headers=headers, json=payload) + + response.raise_for_status() + data = response.json() - usage = data.get("usage", {}) or {} - total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0 + usage = data.get("usage", {}) or {} + total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0 - results: List[EmbeddingResult] = [] - for item in data.get("data", []): - results.append(EmbeddingResult( - embedding=item["embedding"], - tokens_used=total_tokens // max(len(texts), 1), - model=self.model, - )) + results: List[EmbeddingResult] = [] + for item in data.get("data", []): + results.append(EmbeddingResult( + embedding=item["embedding"], + tokens_used=total_tokens // max(len(texts), 1), + model=self.model, + )) - return results + return results except httpx.HTTPStatusError as e: logger.error(f"Qwen embedding API error: {e.response.status_code} - {e.response.text}") - raise RuntimeError(f"Qwen embedding API failed: {e.response.status_code}") from e + raise except httpx.RequestError as e: logger.error(f"Qwen embedding network error: {e}") raise RuntimeError(f"Qwen embedding network error: {e}") from e @@ -639,13 +665,32 @@ class EmbeddingService: ) # 🔥 控制并发请求数 (RPS 限制) - self._semaphore = asyncio.Semaphore(30) - - # 🔥 设置默认批次大小 (对于 remote 模型,用户要求为 10) is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"] + self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", 2 if is_remote else 10) + self._semaphore = asyncio.Semaphore(self.concurrency) + + # 🔥 设置默认批次大小 (DashScope text-embedding-v4 限制为 10) self.batch_size = 10 if is_remote else 100 - logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Batch size: {self.batch_size})") + + + + # 🔥 共享 HTTP 客户端 + self._client: Optional[httpx.AsyncClient] = None + + logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})") + + async def _get_client(self) -> httpx.AsyncClient: + """获取或创建共享的 AsyncClient""" + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(timeout=120.0) + return self._client + + async def close(self): + """关闭共享客户端""" + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None def _create_provider( self, @@ -809,27 +854,34 @@ class EmbeddingService: batch: List[str], indices: List[int], cancel_check: Optional[callable] = None, - max_retries: int = 3 + max_retries: Optional[int] = None ) -> List[EmbeddingResult]: """带重试机制的单批次处理""" - for attempt in range(max_retries): + # 优先使用配置中的重试次数 + actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5) + + client = await self._get_client() + + for attempt in range(actual_max_retries): if cancel_check and cancel_check(): raise asyncio.CancelledError("嵌入操作已取消") async with self._semaphore: try: - return await self._provider.embed_texts(batch) + return await self._provider.embed_texts(batch, client=client) except httpx.HTTPStatusError as e: - if e.response.status_code == 429 and attempt < max_retries - 1: - # 429 
diff --git a/backend/docker-entrypoint.sh b/backend/docker-entrypoint.sh
index ffe6a4f..1a33578 100644
--- a/backend/docker-entrypoint.sh
+++ b/backend/docker-entrypoint.sh
@@ -55,6 +55,7 @@
 exec .venv/bin/gunicorn app.main:app \
     --workers 4 \
     --worker-class uvicorn.workers.UvicornWorker \
     --bind 0.0.0.0:8000 \
-    --timeout 120 \
+    --timeout 1800 \
+    --access-logfile - \
     --error-logfile -
diff --git a/backend/tests/test_indexer_isolated.py b/backend/tests/test_indexer_isolated.py
index ac35e3d..0c4095c 100644
--- a/backend/tests/test_indexer_isolated.py
+++ b/backend/tests/test_indexer_isolated.py
@@ -29,8 +29,9 @@ class TestIndexerLogicIsolated(unittest.IsolatedAsyncioTestCase):
 
         # Mock methods that are called during smart_index_directory
         indexer.initialize = AsyncMock(return_value=(False, ""))
-        indexer._full_index = AsyncMock()
-        indexer._incremental_index = AsyncMock()
+        indexer._full_index = MagicMock()
+        indexer._incremental_index = MagicMock()
+
         async def mock_gen(*args, **kwargs):
             yield MagicMock()
 
@@ -69,10 +70,11 @@ class TestIndexerLogicIsolated(unittest.IsolatedAsyncioTestCase):
             vector_store=mock_vector_store
         )
 
-        indexer._full_index = AsyncMock()
-        indexer._incremental_index = AsyncMock()
+        indexer._full_index = MagicMock()
+        indexer._incremental_index = MagicMock()
 
         async def mock_gen(*args, **kwargs):
             yield MagicMock()
+
         indexer._full_index.side_effect = mock_gen
 
         # Test: Existing collection, but needs_rebuild is True (should be FULL)
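Since `EmbeddingService` now owns a long-lived `httpx.AsyncClient`, the new `close()` method needs a caller at application shutdown or the connection pool leaks until process exit. A minimal sketch of one way to wire that up, assuming a FastAPI app and a module-level `embedding_service` singleton (both hypothetical; the real construction and injection may differ):

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.services.rag.embeddings import EmbeddingService

# Hypothetical singleton; constructor arguments elided.
embedding_service = EmbeddingService()

@asynccontextmanager
async def lifespan(app: FastAPI):
    yield  # application serves requests
    # Release the shared AsyncClient's connection pool on shutdown.
    await embedding_service.close()

app = FastAPI(lifespan=lifespan)
```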