feat: Improve streaming LLM token usage reporting by adding input estimation, requesting usage via `stream_options`, and providing fallback estimation.
parent e13218a33e
commit 31dc476015
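Note: the change calls an estimate_tokens() helper that is defined elsewhere in the codebase and is not part of this diff. Purely as an illustrative sketch (the real helper may well use an actual tokenizer), a character-based heuristic of the kind the fallback estimation relies on could look like this:

# Hypothetical sketch of an estimate_tokens() helper; the real implementation
# used by LiteLLMAdapter is defined elsewhere and is not shown in this commit.
def estimate_tokens(text: str) -> int:
    """Very rough token estimate: ~4 chars per token for ASCII text,
    ~1.5 chars per token for CJK characters."""
    if not text:
        return 0
    cjk_chars = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
    other_chars = len(text) - cjk_chars
    return max(1, int(cjk_chars / 1.5 + other_chars / 4))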
@@ -229,6 +229,9 @@ class LiteLLMAdapter(BaseLLMAdapter):
 
         messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
 
+        # 🔥 Estimate the input token count (used as a fallback when real usage is unavailable)
+        input_tokens_estimate = sum(estimate_tokens(msg["content"]) for msg in messages)
+
         kwargs = {
             "model": self._litellm_model,
             "messages": messages,
@@ -238,6 +241,11 @@ class LiteLLMAdapter(BaseLLMAdapter):
             "stream": True,  # Enable streaming output
         }
 
+        # 🔥 For supported models, request that usage info be included in the streaming output
+        # The OpenAI API supports stream_options
+        if self.config.provider in [LLMProvider.OPENAI, LLMProvider.DEEPSEEK]:
+            kwargs["stream_options"] = {"include_usage": True}
+
         if self.config.api_key and self.config.api_key != "ollama":
             kwargs["api_key"] = self.config.api_key
 
@@ -247,11 +255,21 @@ class LiteLLMAdapter(BaseLLMAdapter):
             kwargs["timeout"] = self.config.timeout
 
         accumulated_content = ""
+        final_usage = None  # 🔥 Holds the final usage info
 
         try:
             response = await litellm.acompletion(**kwargs)
 
             async for chunk in response:
+                # 🔥 Check whether the chunk carries usage info (some APIs include it in the final chunk)
+                if hasattr(chunk, "usage") and chunk.usage:
+                    final_usage = {
+                        "prompt_tokens": chunk.usage.prompt_tokens or 0,
+                        "completion_tokens": chunk.usage.completion_tokens or 0,
+                        "total_tokens": chunk.usage.total_tokens or 0,
+                    }
+                    logger.debug(f"Got usage from chunk: {final_usage}")
+
                 if not chunk.choices:
                     continue
 
@@ -269,27 +287,36 @@ class LiteLLMAdapter(BaseLLMAdapter):
 
                 if finish_reason:
                     # Streaming finished
-                    usage = None
-                    if hasattr(chunk, "usage") and chunk.usage:
-                        usage = {
-                            "prompt_tokens": chunk.usage.prompt_tokens or 0,
-                            "completion_tokens": chunk.usage.completion_tokens or 0,
-                            "total_tokens": chunk.usage.total_tokens or 0,
+                    # 🔥 If usage was not obtained from a chunk, fall back to an estimate
+                    if not final_usage:
+                        output_tokens_estimate = estimate_tokens(accumulated_content)
+                        final_usage = {
+                            "prompt_tokens": input_tokens_estimate,
+                            "completion_tokens": output_tokens_estimate,
+                            "total_tokens": input_tokens_estimate + output_tokens_estimate,
                         }
+                        logger.debug(f"Estimated usage: {final_usage}")
 
                     yield {
                         "type": "done",
                         "content": accumulated_content,
-                        "usage": usage,
+                        "usage": final_usage,
                         "finish_reason": finish_reason,
                     }
                     break
 
         except Exception as e:
+            # 🔥 Even on error, try to return an estimated usage
+            output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
             yield {
                 "type": "error",
                 "error": str(e),
                 "accumulated": accumulated_content,
+                "usage": {
+                    "prompt_tokens": input_tokens_estimate,
+                    "completion_tokens": output_tokens_estimate,
+                    "total_tokens": input_tokens_estimate + output_tokens_estimate,
+                } if accumulated_content else None,
             }
 
     async def validate_config(self) -> bool:
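For reference, the sketch below shows how a caller might consume the stream and pick up the usage field from the final event. Only the "done" and "error" event shapes (content, usage, finish_reason) come from the diff above; the adapter setup, the chat_stream() method name, and the request object are assumptions made for illustration.

# Hypothetical consumer sketch: adapter/request setup and the chat_stream()
# method name are assumed; only the "done"/"error" event dicts come from the diff.
import asyncio

async def consume(adapter, request) -> None:
    async for event in adapter.chat_stream(request):
        if event["type"] == "done":
            usage = event["usage"]  # real usage from the API if available, otherwise an estimate
            if usage:
                print(f"tokens: prompt={usage['prompt_tokens']} "
                      f"completion={usage['completion_tokens']} total={usage['total_tokens']}")
        elif event["type"] == "error":
            # usage may still hold an estimate when some content was already streamed
            print("stream error:", event["error"], "usage:", event["usage"])

# asyncio.run(consume(adapter, request))  # adapter and request assumed to exist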
backend/test_msg.md: 1038 changes (file diff suppressed because it is too large)