Adjust the embedding rate to prevent triggering rate limiting.
Build and Push CodeReview / build (push) Waiting to run
Details
Build and Push CodeReview / build (push) Waiting to run
Details
This commit is contained in:
parent
044cd11ad4
commit
180ae67b7e
|
|
@ -106,7 +106,15 @@ class Settings(BaseSettings):
|
||||||
AGENT_TOKEN_BUDGET: int = 100000 # Agent Token 预算
|
AGENT_TOKEN_BUDGET: int = 100000 # Agent Token 预算
|
||||||
AGENT_TIMEOUT_SECONDS: int = 1800 # Agent 超时时间(30分钟)
|
AGENT_TIMEOUT_SECONDS: int = 1800 # Agent 超时时间(30分钟)
|
||||||
|
|
||||||
|
# 嵌入并发布配置
|
||||||
|
EMBEDDING_CONCURRENCY: int = 2 # 经计算,Batch=10 时并发 2 最接近 1.2M TPM 限流线
|
||||||
|
EMBEDDING_RETRY_MAX: int = 10 # 嵌入模型最大重试次数
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 沙箱配置(必须)
|
# 沙箱配置(必须)
|
||||||
|
|
||||||
SANDBOX_IMAGE: str = "deepaudit/sandbox:latest" # 沙箱 Docker 镜像
|
SANDBOX_IMAGE: str = "deepaudit/sandbox:latest" # 沙箱 Docker 镜像
|
||||||
SANDBOX_MEMORY_LIMIT: str = "512m" # 沙箱内存限制
|
SANDBOX_MEMORY_LIMIT: str = "512m" # 沙箱内存限制
|
||||||
SANDBOX_CPU_LIMIT: float = 1.0 # 沙箱 CPU 限制
|
SANDBOX_CPU_LIMIT: float = 1.0 # 沙箱 CPU 限制
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ class EmbeddingProvider(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
|
||||||
"""批量嵌入文本"""
|
"""批量嵌入文本"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -90,7 +90,7 @@ class OpenAIEmbedding(EmbeddingProvider):
|
||||||
results = await self.embed_texts([text])
|
results = await self.embed_texts([text])
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -109,20 +109,25 @@ class OpenAIEmbedding(EmbeddingProvider):
|
||||||
|
|
||||||
url = f"{self.base_url.rstrip('/')}/embeddings"
|
url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=60) as client:
|
if client:
|
||||||
response = await client.post(url, headers=headers, json=payload)
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
else:
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for item in data.get("data", []):
|
for item in data.get("data", []):
|
||||||
results.append(EmbeddingResult(
|
results.append(EmbeddingResult(
|
||||||
embedding=item["embedding"],
|
embedding=item["embedding"],
|
||||||
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
|
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
))
|
))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIEmbedding(EmbeddingProvider):
|
class AzureOpenAIEmbedding(EmbeddingProvider):
|
||||||
|
|
@ -168,7 +173,7 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
|
||||||
results = await self.embed_texts([text])
|
results = await self.embed_texts([text])
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -187,20 +192,25 @@ class AzureOpenAIEmbedding(EmbeddingProvider):
|
||||||
# Azure URL 格式 - 使用最新 API 版本
|
# Azure URL 格式 - 使用最新 API 版本
|
||||||
url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"
|
url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=60) as client:
|
if client:
|
||||||
response = await client.post(url, headers=headers, json=payload)
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
else:
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for item in data.get("data", []):
|
for item in data.get("data", []):
|
||||||
results.append(EmbeddingResult(
|
results.append(EmbeddingResult(
|
||||||
embedding=item["embedding"],
|
embedding=item["embedding"],
|
||||||
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
|
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
))
|
))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class OllamaEmbedding(EmbeddingProvider):
|
class OllamaEmbedding(EmbeddingProvider):
|
||||||
|
|
@ -238,7 +248,7 @@ class OllamaEmbedding(EmbeddingProvider):
|
||||||
results = await self.embed_texts([text])
|
results = await self.embed_texts([text])
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -250,23 +260,28 @@ class OllamaEmbedding(EmbeddingProvider):
|
||||||
"input": texts, # 新 API 使用 'input' 参数,支持批量
|
"input": texts, # 新 API 使用 'input' 参数,支持批量
|
||||||
}
|
}
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=120) as client:
|
if client:
|
||||||
response = await client.post(url, json=payload)
|
response = await client.post(url, json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
else:
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
response = await client.post(url, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
# 新 API 返回格式: {"embeddings": [[...], [...], ...]}
|
# 新 API 返回格式: {"embeddings": [[...], [...], ...]}
|
||||||
embeddings = data.get("embeddings", [])
|
embeddings = data.get("embeddings", [])
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for i, embedding in enumerate(embeddings):
|
for i, embedding in enumerate(embeddings):
|
||||||
results.append(EmbeddingResult(
|
results.append(EmbeddingResult(
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
tokens_used=len(texts[i]) // 4,
|
tokens_used=len(texts[i]) // 4,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
))
|
))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class CohereEmbedding(EmbeddingProvider):
|
class CohereEmbedding(EmbeddingProvider):
|
||||||
|
|
@ -307,7 +322,7 @@ class CohereEmbedding(EmbeddingProvider):
|
||||||
results = await self.embed_texts([text])
|
results = await self.embed_texts([text])
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -326,24 +341,29 @@ class CohereEmbedding(EmbeddingProvider):
|
||||||
|
|
||||||
url = f"{self.base_url.rstrip('/')}/embed"
|
url = f"{self.base_url.rstrip('/')}/embed"
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=60) as client:
|
if client:
|
||||||
response = await client.post(url, headers=headers, json=payload)
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
else:
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
# v2 API 返回格式: {"embeddings": {"float": [[...], [...]]}, ...}
|
# v2 API 返回格式: {"embeddings": {"float": [[...], [...]]}, ...}
|
||||||
embeddings_data = data.get("embeddings", {})
|
embeddings_data = data.get("embeddings", {})
|
||||||
embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
|
embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
|
||||||
|
|
||||||
for embedding in embeddings:
|
for embedding in embeddings:
|
||||||
results.append(EmbeddingResult(
|
results.append(EmbeddingResult(
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
|
tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
))
|
))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceEmbedding(EmbeddingProvider):
|
class HuggingFaceEmbedding(EmbeddingProvider):
|
||||||
|
|
@ -383,7 +403,7 @@ class HuggingFaceEmbedding(EmbeddingProvider):
|
||||||
results = await self.embed_texts([text])
|
results = await self.embed_texts([text])
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -393,7 +413,6 @@ class HuggingFaceEmbedding(EmbeddingProvider):
|
||||||
}
|
}
|
||||||
|
|
||||||
# 新的 HuggingFace Router URL 格式
|
# 新的 HuggingFace Router URL 格式
|
||||||
# https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
|
|
||||||
url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
|
url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
|
|
@ -403,27 +422,29 @@ class HuggingFaceEmbedding(EmbeddingProvider):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=120) as client:
|
if client:
|
||||||
response = await client.post(url, headers=headers, json=payload)
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
else:
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
# HuggingFace 返回格式: [[embedding1], [embedding2], ...]
|
for embedding in data:
|
||||||
for embedding in data:
|
if isinstance(embedding, list) and len(embedding) > 0:
|
||||||
# 有时候返回的是嵌套的列表
|
if isinstance(embedding[0], list):
|
||||||
if isinstance(embedding, list) and len(embedding) > 0:
|
embedding = embedding[0]
|
||||||
if isinstance(embedding[0], list):
|
|
||||||
# 取平均或第一个
|
|
||||||
embedding = embedding[0]
|
|
||||||
|
|
||||||
results.append(EmbeddingResult(
|
|
||||||
embedding=embedding,
|
|
||||||
tokens_used=len(texts[len(results)]) // 4,
|
|
||||||
model=self.model,
|
|
||||||
))
|
|
||||||
|
|
||||||
return results
|
results.append(EmbeddingResult(
|
||||||
|
embedding=embedding,
|
||||||
|
tokens_used=len(texts[len(results)]) // 4,
|
||||||
|
model=self.model,
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
class JinaEmbedding(EmbeddingProvider):
|
class JinaEmbedding(EmbeddingProvider):
|
||||||
|
|
@ -455,7 +476,7 @@ class JinaEmbedding(EmbeddingProvider):
|
||||||
results = await self.embed_texts([text])
|
results = await self.embed_texts([text])
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -471,20 +492,25 @@ class JinaEmbedding(EmbeddingProvider):
|
||||||
|
|
||||||
url = f"{self.base_url.rstrip('/')}/embeddings"
|
url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=60) as client:
|
if client:
|
||||||
response = await client.post(url, headers=headers, json=payload)
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
else:
|
||||||
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for item in data.get("data", []):
|
for item in data.get("data", []):
|
||||||
results.append(EmbeddingResult(
|
results.append(EmbeddingResult(
|
||||||
embedding=item["embedding"],
|
embedding=item["embedding"],
|
||||||
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
|
tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
))
|
))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class QwenEmbedding(EmbeddingProvider):
|
class QwenEmbedding(EmbeddingProvider):
|
||||||
|
|
@ -537,7 +563,7 @@ class QwenEmbedding(EmbeddingProvider):
|
||||||
results = await self.embed_texts([text])
|
results = await self.embed_texts([text])
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
|
async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -557,30 +583,30 @@ class QwenEmbedding(EmbeddingProvider):
|
||||||
url = f"{self.base_url.rstrip('/')}/embeddings"
|
url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
if client:
|
||||||
response = await client.post(
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
url,
|
else:
|
||||||
headers=headers,
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
json=payload,
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
usage = data.get("usage", {}) or {}
|
usage = data.get("usage", {}) or {}
|
||||||
total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
|
total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
|
||||||
|
|
||||||
results: List[EmbeddingResult] = []
|
results: List[EmbeddingResult] = []
|
||||||
for item in data.get("data", []):
|
for item in data.get("data", []):
|
||||||
results.append(EmbeddingResult(
|
results.append(EmbeddingResult(
|
||||||
embedding=item["embedding"],
|
embedding=item["embedding"],
|
||||||
tokens_used=total_tokens // max(len(texts), 1),
|
tokens_used=total_tokens // max(len(texts), 1),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
))
|
))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(f"Qwen embedding API error: {e.response.status_code} - {e.response.text}")
|
logger.error(f"Qwen embedding API error: {e.response.status_code} - {e.response.text}")
|
||||||
raise RuntimeError(f"Qwen embedding API failed: {e.response.status_code}") from e
|
raise
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"Qwen embedding network error: {e}")
|
logger.error(f"Qwen embedding network error: {e}")
|
||||||
raise RuntimeError(f"Qwen embedding network error: {e}") from e
|
raise RuntimeError(f"Qwen embedding network error: {e}") from e
|
||||||
|
|
@ -639,13 +665,32 @@ class EmbeddingService:
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 控制并发请求数 (RPS 限制)
|
# 🔥 控制并发请求数 (RPS 限制)
|
||||||
self._semaphore = asyncio.Semaphore(30)
|
|
||||||
|
|
||||||
# 🔥 设置默认批次大小 (对于 remote 模型,用户要求为 10)
|
|
||||||
is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
|
is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
|
||||||
|
self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", 2 if is_remote else 10)
|
||||||
|
self._semaphore = asyncio.Semaphore(self.concurrency)
|
||||||
|
|
||||||
|
# 🔥 设置默认批次大小 (DashScope text-embedding-v4 限制为 10)
|
||||||
self.batch_size = 10 if is_remote else 100
|
self.batch_size = 10 if is_remote else 100
|
||||||
|
|
||||||
logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Batch size: {self.batch_size})")
|
|
||||||
|
|
||||||
|
|
||||||
|
# 🔥 共享 HTTP 客户端
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
|
||||||
|
|
||||||
|
async def _get_client(self) -> httpx.AsyncClient:
|
||||||
|
"""获取或创建共享的 AsyncClient"""
|
||||||
|
if self._client is None or self._client.is_closed:
|
||||||
|
self._client = httpx.AsyncClient(timeout=120.0)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""关闭共享客户端"""
|
||||||
|
if self._client and not self._client.is_closed:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
def _create_provider(
|
def _create_provider(
|
||||||
self,
|
self,
|
||||||
|
|
@ -809,27 +854,34 @@ class EmbeddingService:
|
||||||
batch: List[str],
|
batch: List[str],
|
||||||
indices: List[int],
|
indices: List[int],
|
||||||
cancel_check: Optional[callable] = None,
|
cancel_check: Optional[callable] = None,
|
||||||
max_retries: int = 3
|
max_retries: Optional[int] = None
|
||||||
) -> List[EmbeddingResult]:
|
) -> List[EmbeddingResult]:
|
||||||
"""带重试机制的单批次处理"""
|
"""带重试机制的单批次处理"""
|
||||||
for attempt in range(max_retries):
|
# 优先使用配置中的重试次数
|
||||||
|
actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
|
||||||
|
|
||||||
|
client = await self._get_client()
|
||||||
|
|
||||||
|
for attempt in range(actual_max_retries):
|
||||||
if cancel_check and cancel_check():
|
if cancel_check and cancel_check():
|
||||||
raise asyncio.CancelledError("嵌入操作已取消")
|
raise asyncio.CancelledError("嵌入操作已取消")
|
||||||
|
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
try:
|
try:
|
||||||
return await self._provider.embed_texts(batch)
|
return await self._provider.embed_texts(batch, client=client)
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
if e.response.status_code == 429 and attempt < max_retries - 1:
|
if e.response.status_code == 429 and attempt < actual_max_retries - 1:
|
||||||
# 429 限流,指数级退避
|
# 429 限流,使用更保守的指数退避
|
||||||
wait_time = (2 ** attempt) + 1
|
# 第一次重试等待 5s, 第二次 10s, 第三次 20s...
|
||||||
logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{max_retries})")
|
wait_time = (2 ** attempt) * 5
|
||||||
|
logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{actual_max_retries})")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
continue
|
continue
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if attempt < max_retries - 1:
|
if attempt < actual_max_retries - 1:
|
||||||
await asyncio.sleep(1)
|
# 普通错误等待 2s
|
||||||
|
await asyncio.sleep(2)
|
||||||
continue
|
continue
|
||||||
raise
|
raise
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,7 @@ exec .venv/bin/gunicorn app.main:app \
|
||||||
--workers 4 \
|
--workers 4 \
|
||||||
--worker-class uvicorn.workers.UvicornWorker \
|
--worker-class uvicorn.workers.UvicornWorker \
|
||||||
--bind 0.0.0.0:8000 \
|
--bind 0.0.0.0:8000 \
|
||||||
--timeout 120 \
|
--timeout 1800 \
|
||||||
|
|
||||||
--access-logfile - \
|
--access-logfile - \
|
||||||
--error-logfile -
|
--error-logfile -
|
||||||
|
|
|
||||||
|
|
@ -29,8 +29,9 @@ class TestIndexerLogicIsolated(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Mock methods that are called during smart_index_directory
|
# Mock methods that are called during smart_index_directory
|
||||||
indexer.initialize = AsyncMock(return_value=(False, ""))
|
indexer.initialize = AsyncMock(return_value=(False, ""))
|
||||||
indexer._full_index = AsyncMock()
|
indexer._full_index = MagicMock()
|
||||||
indexer._incremental_index = AsyncMock()
|
indexer._incremental_index = MagicMock()
|
||||||
|
|
||||||
|
|
||||||
async def mock_gen(*args, **kwargs):
|
async def mock_gen(*args, **kwargs):
|
||||||
yield MagicMock()
|
yield MagicMock()
|
||||||
|
|
@ -69,10 +70,11 @@ class TestIndexerLogicIsolated(unittest.IsolatedAsyncioTestCase):
|
||||||
vector_store=mock_vector_store
|
vector_store=mock_vector_store
|
||||||
)
|
)
|
||||||
|
|
||||||
indexer._full_index = AsyncMock()
|
indexer._full_index = MagicMock()
|
||||||
indexer._incremental_index = AsyncMock()
|
indexer._incremental_index = MagicMock()
|
||||||
async def mock_gen(*args, **kwargs):
|
async def mock_gen(*args, **kwargs):
|
||||||
yield MagicMock()
|
yield MagicMock()
|
||||||
|
|
||||||
indexer._full_index.side_effect = mock_gen
|
indexer._full_index.side_effect = mock_gen
|
||||||
|
|
||||||
# Test: Existing collection, but needs_rebuild is True (should be FULL)
|
# Test: Existing collection, but needs_rebuild is True (should be FULL)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue