Adjust the embedding request rate to avoid triggering rate limiting.

vinland100 2026-01-08 17:35:21 +08:00
parent 044cd11ad4
commit 180ae67b7e
4 changed files with 182 additions and 119 deletions
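
Summary of the change: embedding providers now accept an optional shared httpx.AsyncClient, EmbeddingService gates requests behind a semaphore sized by the new EMBEDDING_CONCURRENCY setting (2 for remote providers, batch size 10), and 429 responses are retried with a steeper 5s/10s/20s exponential backoff. Below is a minimal self-contained sketch of that throttling pattern against a generic OpenAI-style embeddings endpoint; the function names and response shape are illustrative, not code from this repo.

import asyncio
import httpx

CONCURRENCY = 2    # mirrors EMBEDDING_CONCURRENCY
BATCH_SIZE = 10    # mirrors the remote-provider batch cap
MAX_RETRIES = 10   # mirrors EMBEDDING_RETRY_MAX

async def embed_batch(client: httpx.AsyncClient, sem: asyncio.Semaphore,
                      url: str, batch: list[str]) -> list[list[float]]:
    # POST one batch, retrying 429s with a 5s/10s/20s... backoff.
    for attempt in range(MAX_RETRIES):
        async with sem:  # at most CONCURRENCY requests in flight
            try:
                resp = await client.post(url, json={"input": batch})
                resp.raise_for_status()
                # Assumes an OpenAI-style {"data": [{"embedding": [...]}, ...]} body
                return [item["embedding"] for item in resp.json()["data"]]
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 429 and attempt < MAX_RETRIES - 1:
                    await asyncio.sleep((2 ** attempt) * 5)
                    continue
                raise
    return []

async def embed_all(url: str, texts: list[str]) -> list[list[float]]:
    sem = asyncio.Semaphore(CONCURRENCY)
    batches = [texts[i:i + BATCH_SIZE] for i in range(0, len(texts), BATCH_SIZE)]
    # One shared client for all batches instead of a new client per request
    async with httpx.AsyncClient(timeout=120.0) as client:
        chunks = await asyncio.gather(*(embed_batch(client, sem, url, b) for b in batches))
    return [vec for chunk in chunks for vec in chunk]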

View File

@@ -106,7 +106,15 @@ class Settings(BaseSettings):
     AGENT_TOKEN_BUDGET: int = 100000  # Agent token budget
     AGENT_TIMEOUT_SECONDS: int = 1800  # Agent timeout, 30 minutes
 
+    # Embedding concurrency configuration
+    EMBEDDING_CONCURRENCY: int = 2  # Calculated: with batch=10, concurrency 2 comes closest to the 1.2M TPM rate-limit line
+    EMBEDDING_RETRY_MAX: int = 10  # Maximum retries for the embedding model
+
     # Sandbox configuration (required)
     SANDBOX_IMAGE: str = "deepaudit/sandbox:latest"  # Sandbox Docker image
     SANDBOX_MEMORY_LIMIT: str = "512m"  # Sandbox memory limit
     SANDBOX_CPU_LIMIT: float = 1.0  # Sandbox CPU limit
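
The EMBEDDING_CONCURRENCY comment compresses a throughput calculation. A back-of-the-envelope check, where the tokens-per-text and per-request latency are assumed illustrative values rather than numbers from this repo:

# Hypothetical sanity check for EMBEDDING_CONCURRENCY = 2.
batch_size = 10          # texts per request
concurrency = 2          # parallel requests
tokens_per_text = 2000   # assumed average chunk size
latency_s = 2.0          # assumed round trip per request

requests_per_min = concurrency * (60 / latency_s)                 # 60
tokens_per_min = requests_per_min * batch_size * tokens_per_text  # 1,200,000
print(f"{tokens_per_min:,.0f} TPM")  # lands right on the 1.2M TPM limit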

View File

@@ -34,7 +34,7 @@ class EmbeddingProvider(ABC):
         pass
 
     @abstractmethod
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         """Embed a batch of texts"""
         pass
@@ -90,7 +90,7 @@ class OpenAIEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
@@ -109,6 +109,11 @@ class OpenAIEmbedding(EmbeddingProvider):
         url = f"{self.base_url.rstrip('/')}/embeddings"
 
+        if client:
+            response = await client.post(url, headers=headers, json=payload)
+            response.raise_for_status()
+            data = response.json()
+        else:
             async with httpx.AsyncClient(timeout=60) as client:
                 response = await client.post(url, headers=headers, json=payload)
                 response.raise_for_status()
@@ -168,7 +173,7 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
@@ -187,6 +192,11 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
         # Azure URL format - use the latest API version
         url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"
 
+        if client:
+            response = await client.post(url, headers=headers, json=payload)
+            response.raise_for_status()
+            data = response.json()
+        else:
             async with httpx.AsyncClient(timeout=60) as client:
                 response = await client.post(url, headers=headers, json=payload)
                 response.raise_for_status()
@@ -238,7 +248,7 @@ class OllamaEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
@@ -250,6 +260,11 @@ class OllamaEmbedding(EmbeddingProvider):
             "input": texts,  # The new API uses the 'input' parameter and supports batching
         }
 
+        if client:
+            response = await client.post(url, json=payload)
+            response.raise_for_status()
+            data = response.json()
+        else:
             async with httpx.AsyncClient(timeout=120) as client:
                 response = await client.post(url, json=payload)
                 response.raise_for_status()
@@ -307,7 +322,7 @@ class CohereEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
@@ -326,6 +341,11 @@ class CohereEmbedding(EmbeddingProvider):
         url = f"{self.base_url.rstrip('/')}/embed"
 
+        if client:
+            response = await client.post(url, headers=headers, json=payload)
+            response.raise_for_status()
+            data = response.json()
+        else:
             async with httpx.AsyncClient(timeout=60) as client:
                 response = await client.post(url, headers=headers, json=payload)
                 response.raise_for_status()
@@ -383,7 +403,7 @@ class HuggingFaceEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
@@ -393,7 +413,6 @@ class HuggingFaceEmbedding(EmbeddingProvider):
         }
 
         # New HuggingFace Router URL format
-        # https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
         url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
 
         payload = {
@@ -403,18 +422,20 @@ class HuggingFaceEmbedding(EmbeddingProvider):
             }
         }
 
+        if client:
+            response = await client.post(url, headers=headers, json=payload)
+            response.raise_for_status()
+            data = response.json()
+        else:
             async with httpx.AsyncClient(timeout=120) as client:
                 response = await client.post(url, headers=headers, json=payload)
                 response.raise_for_status()
                 data = response.json()
 
         results = []
-        # HuggingFace response format: [[embedding1], [embedding2], ...]
         for embedding in data:
-            # Sometimes the response comes back as a nested list
             if isinstance(embedding, list) and len(embedding) > 0:
                 if isinstance(embedding[0], list):
-                    # Take the average or just the first one
                     embedding = embedding[0]
 
             results.append(EmbeddingResult(
@@ -455,7 +476,7 @@ class JinaEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
@@ -471,6 +492,11 @@ class JinaEmbedding(EmbeddingProvider):
         url = f"{self.base_url.rstrip('/')}/embeddings"
 
+        if client:
+            response = await client.post(url, headers=headers, json=payload)
+            response.raise_for_status()
+            data = response.json()
+        else:
             async with httpx.AsyncClient(timeout=60) as client:
                 response = await client.post(url, headers=headers, json=payload)
                 response.raise_for_status()
@@ -537,7 +563,7 @@ class QwenEmbedding(EmbeddingProvider):
         results = await self.embed_texts([text])
         return results[0]
 
-    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
         if not texts:
             return []
@@ -557,12 +583,12 @@ class QwenEmbedding(EmbeddingProvider):
         url = f"{self.base_url.rstrip('/')}/embeddings"
 
         try:
+            if client:
+                response = await client.post(url, headers=headers, json=payload)
+            else:
                 async with httpx.AsyncClient(timeout=60.0) as client:
-                    response = await client.post(
-                        url,
-                        headers=headers,
-                        json=payload,
-                    )
+                    response = await client.post(url, headers=headers, json=payload)
 
             response.raise_for_status()
             data = response.json()
@@ -580,7 +606,7 @@ class QwenEmbedding(EmbeddingProvider):
             return results
 
         except httpx.HTTPStatusError as e:
             logger.error(f"Qwen embedding API error: {e.response.status_code} - {e.response.text}")
-            raise RuntimeError(f"Qwen embedding API failed: {e.response.status_code}") from e
+            raise
         except httpx.RequestError as e:
             logger.error(f"Qwen embedding network error: {e}")
             raise RuntimeError(f"Qwen embedding network error: {e}") from e
@@ -639,13 +665,32 @@ class EmbeddingService:
         )
 
         # 🔥 Cap concurrent requests (RPS limit)
-        self._semaphore = asyncio.Semaphore(30)
-        # 🔥 Default batch size (10 requested for remote models)
         is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
+        self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", 2 if is_remote else 10)
+        self._semaphore = asyncio.Semaphore(self.concurrency)
+        # 🔥 Default batch size (DashScope text-embedding-v4 caps batches at 10)
         self.batch_size = 10 if is_remote else 100
-        logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Batch size: {self.batch_size})")
+        # 🔥 Shared HTTP client
+        self._client: Optional[httpx.AsyncClient] = None
+        logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
+
+    async def _get_client(self) -> httpx.AsyncClient:
+        """Get or create the shared AsyncClient"""
+        if self._client is None or self._client.is_closed:
+            self._client = httpx.AsyncClient(timeout=120.0)
+        return self._client
+
+    async def close(self):
+        """Close the shared client"""
+        if self._client and not self._client.is_closed:
+            await self._client.aclose()
+            self._client = None
 
     def _create_provider(
         self,
@@ -809,27 +854,34 @@ class EmbeddingService:
         batch: List[str],
         indices: List[int],
         cancel_check: Optional[callable] = None,
-        max_retries: int = 3
+        max_retries: Optional[int] = None
     ) -> List[EmbeddingResult]:
         """Process a single batch with retry logic"""
-        for attempt in range(max_retries):
+        # Prefer the retry count from settings
+        actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
+        client = await self._get_client()
+        for attempt in range(actual_max_retries):
             if cancel_check and cancel_check():
                 raise asyncio.CancelledError("Embedding operation cancelled")
 
             async with self._semaphore:
                 try:
-                    return await self._provider.embed_texts(batch)
+                    return await self._provider.embed_texts(batch, client=client)
                 except httpx.HTTPStatusError as e:
-                    if e.response.status_code == 429 and attempt < max_retries - 1:
-                        # 429 rate limited, exponential backoff
-                        wait_time = (2 ** attempt) + 1
-                        logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{max_retries})")
+                    if e.response.status_code == 429 and attempt < actual_max_retries - 1:
+                        # 429 rate limited: back off more conservatively
+                        # First retry waits 5s, second 10s, third 20s...
+                        wait_time = (2 ** attempt) * 5
+                        logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{actual_max_retries})")
                         await asyncio.sleep(wait_time)
                         continue
                     raise
                 except Exception as e:
-                    if attempt < max_retries - 1:
-                        await asyncio.sleep(1)
+                    if attempt < actual_max_retries - 1:
+                        # Plain errors wait 2s before retrying
+                        await asyncio.sleep(2)
                         continue
                     raise
         return []
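
Because EmbeddingService now lazily creates a shared httpx.AsyncClient, callers should release it when done. A sketch of the expected lifecycle, assuming the import path and an embed_documents entry point (this diff only shows the private batch helper):

import asyncio

from app.services.embedding import EmbeddingService  # assumed import path

async def main() -> None:
    service = EmbeddingService()
    try:
        # Assumed public entry point; not shown in this diff.
        vectors = await service.embed_documents(["def foo(): ...", "class Bar: ..."])
        print(f"embedded {len(vectors)} chunks")
    finally:
        await service.close()  # closes the shared AsyncClient from _get_client()

asyncio.run(main())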

View File

@@ -55,6 +55,7 @@ exec .venv/bin/gunicorn app.main:app \
     --workers 4 \
     --worker-class uvicorn.workers.UvicornWorker \
     --bind 0.0.0.0:8000 \
-    --timeout 120 \
+    --timeout 1800 \
     --access-logfile - \
     --error-logfile -

View File

@@ -29,8 +29,9 @@ class TestIndexerLogicIsolated(unittest.IsolatedAsyncioTestCase):
         # Mock methods that are called during smart_index_directory
         indexer.initialize = AsyncMock(return_value=(False, ""))
-        indexer._full_index = AsyncMock()
-        indexer._incremental_index = AsyncMock()
+        indexer._full_index = MagicMock()
+        indexer._incremental_index = MagicMock()
 
         async def mock_gen(*args, **kwargs):
             yield MagicMock()
@@ -69,10 +70,11 @@ class TestIndexerLogicIsolated(unittest.IsolatedAsyncioTestCase):
             vector_store=mock_vector_store
         )
-        indexer._full_index = AsyncMock()
-        indexer._incremental_index = AsyncMock()
+        indexer._full_index = MagicMock()
+        indexer._incremental_index = MagicMock()
 
         async def mock_gen(*args, **kwargs):
             yield MagicMock()
 
         indexer._full_index.side_effect = mock_gen
 
         # Test: Existing collection, but needs_rebuild is True (should be FULL)