"""
LLM token streaming handler.

Supports streaming output from multiple LLM providers.
"""

import asyncio
import logging
from typing import Any, Dict, Optional, AsyncGenerator, Callable
from dataclasses import dataclass
from datetime import datetime, timezone

logger = logging.getLogger(__name__)


@dataclass
class TokenChunk:
    """A single streamed token chunk."""

    content: str
    token_count: int = 1
    finish_reason: Optional[str] = None
    model: Optional[str] = None

    # Running totals for the stream so far
    accumulated_content: str = ""
    total_tokens: int = 0


class TokenStreamer:
    """
    LLM token streaming handler.

    Best practices:
    1. Use LiteLLM's streaming API.
    2. Emit each token to the caller as soon as it arrives.
    3. Track accumulated content and token usage.
    4. Support cancellation and timeouts.

    A usage sketch is provided at the end of this module.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        on_token: Optional[Callable[[TokenChunk], None]] = None,
    ):
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.on_token = on_token

        self._cancelled = False
        self._accumulated_content = ""
        self._total_tokens = 0

    def cancel(self):
        """Cancel the streaming run."""
        self._cancelled = True

    async def stream_completion(
        self,
        messages: list[Dict[str, str]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> AsyncGenerator[TokenChunk, None]:
        """
        Stream a completion from the LLM.

        Args:
            messages: List of chat messages.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            TokenChunk: One chunk per streamed token.
        """
        try:
            import litellm

            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
                base_url=self.base_url,
                stream=True,  # enable streaming
            )

            async for chunk in response:
                if self._cancelled:
                    break

                # Extract the delta content and finish reason
                content = ""
                finish_reason = None

                if hasattr(chunk, "choices") and chunk.choices:
                    choice = chunk.choices[0]
                    if hasattr(choice, "delta") and choice.delta:
                        content = getattr(choice.delta, "content", "") or ""
                    finish_reason = getattr(choice, "finish_reason", None)

                if content:
                    self._accumulated_content += content
                    self._total_tokens += 1

                    token_chunk = TokenChunk(
                        content=content,
                        token_count=1,
                        finish_reason=finish_reason,
                        model=self.model,
                        accumulated_content=self._accumulated_content,
                        total_tokens=self._total_tokens,
                    )

                    # Notify the callback, if one was provided
                    if self.on_token:
                        self.on_token(token_chunk)

                    yield token_chunk

                # Stop once the provider signals completion
                if finish_reason:
                    break

        except asyncio.CancelledError:
            logger.info("Token streaming cancelled")
            raise

        except Exception as e:
            logger.error(f"Token streaming error: {e}")
            raise

    async def stream_with_tools(
        self,
        messages: list[Dict[str, str]],
        tools: list[Dict[str, Any]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Stream a completion with tool calling enabled.

        Args:
            messages: List of chat messages.
            tools: List of tool definitions.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            Dicts describing either a token or a tool call.
        """
        try:
            import litellm

            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                tools=tools,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
                base_url=self.base_url,
                stream=True,
            )

            # Accumulator for partial tool calls, keyed by tool-call index
            tool_calls_accumulator: Dict[int, Dict] = {}

            async for chunk in response:
                if self._cancelled:
                    break

                if not hasattr(chunk, "choices") or not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = getattr(choice, "delta", None)
                finish_reason = getattr(choice, "finish_reason", None)

                if delta:
                    # Handle text content
                    content = getattr(delta, "content", "") or ""
                    if content:
                        self._accumulated_content += content
                        self._total_tokens += 1

                        yield {
                            "type": "token",
                            "content": content,
                            "accumulated": self._accumulated_content,
                            "total_tokens": self._total_tokens,
                        }

                    # Handle tool-call fragments
                    tool_calls = getattr(delta, "tool_calls", None) or []
                    for tool_call in tool_calls:
                        idx = tool_call.index

                        if idx not in tool_calls_accumulator:
                            tool_calls_accumulator[idx] = {
                                "id": tool_call.id or "",
                                "name": "",
                                "arguments": "",
                            }

                        if tool_call.function:
                            if tool_call.function.name:
                                tool_calls_accumulator[idx]["name"] = tool_call.function.name
                            if tool_call.function.arguments:
                                tool_calls_accumulator[idx]["arguments"] += tool_call.function.arguments

                        yield {
                            "type": "tool_call_chunk",
                            "index": idx,
                            "tool_call": tool_calls_accumulator[idx],
                        }

                # When the model finishes with tool calls, emit the fully assembled calls
                if finish_reason == "tool_calls":
                    for idx, tool_call in tool_calls_accumulator.items():
                        yield {
                            "type": "tool_call_complete",
                            "index": idx,
                            "tool_call": tool_call,
                        }

                if finish_reason:
                    yield {
                        "type": "finish",
                        "reason": finish_reason,
                        "accumulated": self._accumulated_content,
                        "total_tokens": self._total_tokens,
                    }
                    break

        except asyncio.CancelledError:
            logger.info("Tool streaming cancelled")
            raise

        except Exception as e:
            logger.error(f"Tool streaming error: {e}")
            yield {
                "type": "error",
                "error": str(e),
            }

    def get_accumulated_content(self) -> str:
        """Return the accumulated text content."""
        return self._accumulated_content

    def get_total_tokens(self) -> int:
        """Return the total number of streamed tokens."""
        return self._total_tokens

    def reset(self):
        """Reset the streamer state."""
        self._cancelled = False
        self._accumulated_content = ""
        self._total_tokens = 0
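

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the core module).
#
# A minimal sketch of plain token streaming with TokenStreamer. The model name
# "gpt-4o-mini" is only a placeholder, and provider credentials are assumed to
# come from the environment (or be passed via api_key).
# ---------------------------------------------------------------------------
async def _example_stream_tokens() -> None:
    # Print each token as it arrives via the on_token callback.
    streamer = TokenStreamer(
        model="gpt-4o-mini",  # placeholder model name
        on_token=lambda chunk: print(chunk.content, end="", flush=True),
    )
    async for _ in streamer.stream_completion(
        messages=[{"role": "user", "content": "Say hello."}],
    ):
        pass  # tokens are already printed by the callback
    print(f"\ntotal tokens: {streamer.get_total_tokens()}")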
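

# A minimal sketch of streaming with tool calling. It assumes an OpenAI-style
# function tool schema (the "get_weather" tool below is hypothetical) and
# dispatches on the "type" field of each dict yielded by stream_with_tools().
async def _example_stream_with_tools() -> None:
    streamer = TokenStreamer(model="gpt-4o-mini")  # placeholder model name
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool, for illustration only
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]
    async for event in streamer.stream_with_tools(
        messages=[{"role": "user", "content": "What is the weather in Paris?"}],
        tools=tools,
    ):
        if event["type"] == "token":
            print(event["content"], end="", flush=True)
        elif event["type"] == "tool_call_complete":
            print(f"\ntool call: {event['tool_call']}")
        elif event["type"] == "error":
            print(f"\nerror: {event['error']}")


if __name__ == "__main__":
    # Requires network access and provider credentials; run one sketch at a time.
    asyncio.run(_example_stream_tokens())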