Adjust the embedding rate to prevent triggering rate limiting.
Build and Push CodeReview / build (push) Waiting to run Details

This commit is contained in:
vinland100 2026-01-08 17:35:21 +08:00
parent 044cd11ad4
commit 180ae67b7e
4 changed files with 182 additions and 119 deletions

View File

@ -106,7 +106,15 @@ class Settings(BaseSettings):
AGENT_TOKEN_BUDGET: int = 100000 # Agent Token 预算
AGENT_TIMEOUT_SECONDS: int = 1800 # Agent 超时时间30分钟
# 嵌入并发布配置
EMBEDDING_CONCURRENCY: int = 2 # 经计算Batch=10 时并发 2 最接近 1.2M TPM 限流线
EMBEDDING_RETRY_MAX: int = 10 # 嵌入模型最大重试次数
# 沙箱配置(必须)
SANDBOX_IMAGE: str = "deepaudit/sandbox:latest" # 沙箱 Docker 镜像
SANDBOX_MEMORY_LIMIT: str = "512m" # 沙箱内存限制
SANDBOX_CPU_LIMIT: float = 1.0 # 沙箱 CPU 限制

View File

@ -34,7 +34,7 @@ class EmbeddingProvider(ABC):
pass
@abstractmethod
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
"""批量嵌入文本"""
pass
@ -90,7 +90,7 @@ class OpenAIEmbedding(EmbeddingProvider):
results = await self.embed_texts([text])
return results[0]
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
if not texts:
return []
@ -109,20 +109,25 @@ class OpenAIEmbedding(EmbeddingProvider):
url = f"{self.base_url.rstrip('/')}/embeddings"
async with httpx.AsyncClient(timeout=60) as client:
if client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
else:
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
for item in data.get("data", []):
results.append(EmbeddingResult(
embedding=item["embedding"],
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
model=self.model,
))
return results
results = []
for item in data.get("data", []):
results.append(EmbeddingResult(
embedding=item["embedding"],
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
model=self.model,
))
return results
class AzureOpenAIEmbedding(EmbeddingProvider):
@ -168,7 +173,7 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
results = await self.embed_texts([text])
return results[0]
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
if not texts:
return []
@ -187,20 +192,25 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
# Azure URL 格式 - 使用最新 API 版本
url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"
async with httpx.AsyncClient(timeout=60) as client:
if client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
else:
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
for item in data.get("data", []):
results.append(EmbeddingResult(
embedding=item["embedding"],
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
model=self.model,
))
return results
results = []
for item in data.get("data", []):
results.append(EmbeddingResult(
embedding=item["embedding"],
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
model=self.model,
))
return results
class OllamaEmbedding(EmbeddingProvider):
@ -238,7 +248,7 @@ class OllamaEmbedding(EmbeddingProvider):
results = await self.embed_texts([text])
return results[0]
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
if not texts:
return []
@ -250,23 +260,28 @@ class OllamaEmbedding(EmbeddingProvider):
"input": texts, # 新 API 使用 'input' 参数,支持批量
}
async with httpx.AsyncClient(timeout=120) as client:
if client:
response = await client.post(url, json=payload)
response.raise_for_status()
data = response.json()
else:
async with httpx.AsyncClient(timeout=120) as client:
response = await client.post(url, json=payload)
response.raise_for_status()
data = response.json()
# 新 API 返回格式: {"embeddings": [[...], [...], ...]}
embeddings = data.get("embeddings", [])
results = []
for i, embedding in enumerate(embeddings):
results.append(EmbeddingResult(
embedding=embedding,
tokens_used=len(texts[i]) // 4,
model=self.model,
))
return results
# 新 API 返回格式: {"embeddings": [[...], [...], ...]}
embeddings = data.get("embeddings", [])
results = []
for i, embedding in enumerate(embeddings):
results.append(EmbeddingResult(
embedding=embedding,
tokens_used=len(texts[i]) // 4,
model=self.model,
))
return results
class CohereEmbedding(EmbeddingProvider):
@ -307,7 +322,7 @@ class CohereEmbedding(EmbeddingProvider):
results = await self.embed_texts([text])
return results[0]
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
if not texts:
return []
@ -326,24 +341,29 @@ class CohereEmbedding(EmbeddingProvider):
url = f"{self.base_url.rstrip('/')}/embed"
async with httpx.AsyncClient(timeout=60) as client:
if client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
else:
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
# v2 API 返回格式: {"embeddings": {"float": [[...], [...]]}, ...}
embeddings_data = data.get("embeddings", {})
embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
for embedding in embeddings:
results.append(EmbeddingResult(
embedding=embedding,
tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
model=self.model,
))
return results
results = []
# v2 API 返回格式: {"embeddings": {"float": [[...], [...]]}, ...}
embeddings_data = data.get("embeddings", {})
embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
for embedding in embeddings:
results.append(EmbeddingResult(
embedding=embedding,
tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
model=self.model,
))
return results
class HuggingFaceEmbedding(EmbeddingProvider):
@ -383,7 +403,7 @@ class HuggingFaceEmbedding(EmbeddingProvider):
results = await self.embed_texts([text])
return results[0]
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
if not texts:
return []
@ -393,7 +413,6 @@ class HuggingFaceEmbedding(EmbeddingProvider):
}
# 新的 HuggingFace Router URL 格式
# https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
payload = {
@ -403,27 +422,29 @@ class HuggingFaceEmbedding(EmbeddingProvider):
}
}
async with httpx.AsyncClient(timeout=120) as client:
if client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
else:
async with httpx.AsyncClient(timeout=120) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
# HuggingFace 返回格式: [[embedding1], [embedding2], ...]
for embedding in data:
# 有时候返回的是嵌套的列表
if isinstance(embedding, list) and len(embedding) > 0:
if isinstance(embedding[0], list):
# 取平均或第一个
embedding = embedding[0]
results.append(EmbeddingResult(
embedding=embedding,
tokens_used=len(texts[len(results)]) // 4,
model=self.model,
))
results = []
for embedding in data:
if isinstance(embedding, list) and len(embedding) > 0:
if isinstance(embedding[0], list):
embedding = embedding[0]
return results
results.append(EmbeddingResult(
embedding=embedding,
tokens_used=len(texts[len(results)]) // 4,
model=self.model,
))
return results
class JinaEmbedding(EmbeddingProvider):
@ -455,7 +476,7 @@ class JinaEmbedding(EmbeddingProvider):
results = await self.embed_texts([text])
return results[0]
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
if not texts:
return []
@ -471,20 +492,25 @@ class JinaEmbedding(EmbeddingProvider):
url = f"{self.base_url.rstrip('/')}/embeddings"
async with httpx.AsyncClient(timeout=60) as client:
if client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
else:
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
for item in data.get("data", []):
results.append(EmbeddingResult(
embedding=item["embedding"],
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
model=self.model,
))
return results
results = []
for item in data.get("data", []):
results.append(EmbeddingResult(
embedding=item["embedding"],
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
model=self.model,
))
return results
class QwenEmbedding(EmbeddingProvider):
@ -537,7 +563,7 @@ class QwenEmbedding(EmbeddingProvider):
results = await self.embed_texts([text])
return results[0]
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
if not texts:
return []
@ -557,30 +583,30 @@ class QwenEmbedding(EmbeddingProvider):
url = f"{self.base_url.rstrip('/')}/embeddings"
try:
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
url,
headers=headers,
json=payload,
)
response.raise_for_status()
data = response.json()
if client:
response = await client.post(url, headers=headers, json=payload)
else:
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
usage = data.get("usage", {}) or {}
total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
usage = data.get("usage", {}) or {}
total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
results: List[EmbeddingResult] = []
for item in data.get("data", []):
results.append(EmbeddingResult(
embedding=item["embedding"],
tokens_used=total_tokens // max(len(texts), 1),
model=self.model,
))
results: List[EmbeddingResult] = []
for item in data.get("data", []):
results.append(EmbeddingResult(
embedding=item["embedding"],
tokens_used=total_tokens // max(len(texts), 1),
model=self.model,
))
return results
return results
except httpx.HTTPStatusError as e:
logger.error(f"Qwen embedding API error: {e.response.status_code} - {e.response.text}")
raise RuntimeError(f"Qwen embedding API failed: {e.response.status_code}") from e
raise
except httpx.RequestError as e:
logger.error(f"Qwen embedding network error: {e}")
raise RuntimeError(f"Qwen embedding network error: {e}") from e
@ -639,13 +665,32 @@ class EmbeddingService:
)
# 🔥 控制并发请求数 (RPS 限制)
self._semaphore = asyncio.Semaphore(30)
# 🔥 设置默认批次大小 (对于 remote 模型,用户要求为 10)
is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", 2 if is_remote else 10)
self._semaphore = asyncio.Semaphore(self.concurrency)
# 🔥 设置默认批次大小 (DashScope text-embedding-v4 限制为 10)
self.batch_size = 10 if is_remote else 100
logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Batch size: {self.batch_size})")
# 🔥 共享 HTTP 客户端
self._client: Optional[httpx.AsyncClient] = None
logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
async def _get_client(self) -> httpx.AsyncClient:
"""获取或创建共享的 AsyncClient"""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(timeout=120.0)
return self._client
async def close(self):
"""关闭共享客户端"""
if self._client and not self._client.is_closed:
await self._client.aclose()
self._client = None
def _create_provider(
self,
@ -809,27 +854,34 @@ class EmbeddingService:
batch: List[str],
indices: List[int],
cancel_check: Optional[callable] = None,
max_retries: int = 3
max_retries: Optional[int] = None
) -> List[EmbeddingResult]:
"""带重试机制的单批次处理"""
for attempt in range(max_retries):
# 优先使用配置中的重试次数
actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
client = await self._get_client()
for attempt in range(actual_max_retries):
if cancel_check and cancel_check():
raise asyncio.CancelledError("嵌入操作已取消")
async with self._semaphore:
try:
return await self._provider.embed_texts(batch)
return await self._provider.embed_texts(batch, client=client)
except httpx.HTTPStatusError as e:
if e.response.status_code == 429 and attempt < max_retries - 1:
# 429 限流,指数级退避
wait_time = (2 ** attempt) + 1
logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{max_retries})")
if e.response.status_code == 429 and attempt < actual_max_retries - 1:
# 429 限流,使用更保守的指数退避
# 第一次重试等待 5s, 第二次 10s, 第三次 20s...
wait_time = (2 ** attempt) * 5
logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{actual_max_retries})")
await asyncio.sleep(wait_time)
continue
raise
except Exception as e:
if attempt < max_retries - 1:
await asyncio.sleep(1)
if attempt < actual_max_retries - 1:
# 普通错误等待 2s
await asyncio.sleep(2)
continue
raise
return []

View File

@ -55,6 +55,7 @@ exec .venv/bin/gunicorn app.main:app \
--workers 4 \
--worker-class uvicorn.workers.UvicornWorker \
--bind 0.0.0.0:8000 \
--timeout 120 \
--timeout 1800 \
--access-logfile - \
--error-logfile -

View File

@ -29,8 +29,9 @@ class TestIndexerLogicIsolated(unittest.IsolatedAsyncioTestCase):
# Mock methods that are called during smart_index_directory
indexer.initialize = AsyncMock(return_value=(False, ""))
indexer._full_index = AsyncMock()
indexer._incremental_index = AsyncMock()
indexer._full_index = MagicMock()
indexer._incremental_index = MagicMock()
async def mock_gen(*args, **kwargs):
yield MagicMock()
@ -69,10 +70,11 @@ class TestIndexerLogicIsolated(unittest.IsolatedAsyncioTestCase):
vector_store=mock_vector_store
)
indexer._full_index = AsyncMock()
indexer._incremental_index = AsyncMock()
indexer._full_index = MagicMock()
indexer._incremental_index = MagicMock()
async def mock_gen(*args, **kwargs):
yield MagicMock()
indexer._full_index.side_effect = mock_gen
# Test: Existing collection, but needs_rebuild is True (should be FULL)