CodeReview/backend/app/services/agent/streaming/token_streamer.py


"""
LLM Token 流式输出处理器
支持多种 LLM 提供商的流式输出
"""
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Callable, Dict, Optional

logger = logging.getLogger(__name__)

@dataclass
class TokenChunk:
    """A single streamed token chunk."""

    content: str
    token_count: int = 1
    finish_reason: Optional[str] = None
    model: Optional[str] = None
    # Running statistics
    accumulated_content: str = ""
    total_tokens: int = 0

class TokenStreamer:
    """
    LLM token streaming handler.

    Best practices:
    1. Use LiteLLM's streaming API.
    2. Emit every token as it arrives.
    3. Track accumulated content and token usage.
    4. Support cancellation mid-stream.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        on_token: Optional[Callable[[TokenChunk], None]] = None,
    ):
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.on_token = on_token
        self._cancelled = False
        self._accumulated_content = ""
        self._total_tokens = 0

    def cancel(self):
        """Cancel the stream; checked between chunks."""
        self._cancelled = True

    async def stream_completion(
        self,
        messages: list[Dict[str, str]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> AsyncGenerator[TokenChunk, None]:
        """
        Stream a completion from the LLM.

        Args:
            messages: Chat messages.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            TokenChunk: one chunk per streamed delta.
        """
        try:
            import litellm

            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
                base_url=self.base_url,
                stream=True,  # Enable streaming
            )
            async for chunk in response:
                if self._cancelled:
                    break
                # Extract content from the delta
                content = ""
                finish_reason = None
                if hasattr(chunk, "choices") and chunk.choices:
                    choice = chunk.choices[0]
                    if hasattr(choice, "delta") and choice.delta:
                        content = getattr(choice.delta, "content", "") or ""
                    finish_reason = getattr(choice, "finish_reason", None)
                if content:
                    self._accumulated_content += content
                    # Approximation: counts one token per streamed delta;
                    # a delta may actually carry more than one token.
                    self._total_tokens += 1
                    token_chunk = TokenChunk(
                        content=content,
                        token_count=1,
                        finish_reason=finish_reason,
                        model=self.model,
                        accumulated_content=self._accumulated_content,
                        total_tokens=self._total_tokens,
                    )
                    # Invoke the per-token callback, if any
                    if self.on_token:
                        self.on_token(token_chunk)
                    yield token_chunk
                # Stop once the provider signals completion
                if finish_reason:
                    break
        except asyncio.CancelledError:
            logger.info("Token streaming cancelled")
            raise
        except Exception as e:
            logger.error(f"Token streaming error: {e}")
            raise

    async def stream_with_tools(
        self,
        messages: list[Dict[str, str]],
        tools: list[Dict[str, Any]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Stream a completion that may include tool calls.

        Args:
            messages: Chat messages.
            tools: Tool definitions.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            Dicts describing either a token or a tool-call event.
        """
        try:
            import litellm

            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                tools=tools,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
                base_url=self.base_url,
                stream=True,
            )
            # Accumulates partial tool calls, keyed by index, across chunks
            tool_calls_accumulator: Dict[int, Dict] = {}
            async for chunk in response:
                if self._cancelled:
                    break
                if not hasattr(chunk, "choices") or not chunk.choices:
                    continue
                choice = chunk.choices[0]
                delta = getattr(choice, "delta", None)
                finish_reason = getattr(choice, "finish_reason", None)
                if delta:
                    # Handle text content
                    content = getattr(delta, "content", "") or ""
                    if content:
                        self._accumulated_content += content
                        self._total_tokens += 1
                        yield {
                            "type": "token",
                            "content": content,
                            "accumulated": self._accumulated_content,
                            "total_tokens": self._total_tokens,
                        }
                    # Handle tool calls: name and arguments arrive in fragments
                    tool_calls = getattr(delta, "tool_calls", None) or []
                    for tool_call in tool_calls:
                        idx = tool_call.index
                        if idx not in tool_calls_accumulator:
                            tool_calls_accumulator[idx] = {
                                "id": "",
                                "name": "",
                                "arguments": "",
                            }
                        # The id may arrive in any chunk, not only the first
                        if tool_call.id:
                            tool_calls_accumulator[idx]["id"] = tool_call.id
                        if tool_call.function:
                            if tool_call.function.name:
                                tool_calls_accumulator[idx]["name"] = tool_call.function.name
                            if tool_call.function.arguments:
                                tool_calls_accumulator[idx]["arguments"] += tool_call.function.arguments
                        yield {
                            "type": "tool_call_chunk",
                            "index": idx,
                            "tool_call": tool_calls_accumulator[idx],
                        }
                # On completion, emit the fully assembled tool calls
                if finish_reason == "tool_calls":
                    for idx, tool_call in tool_calls_accumulator.items():
                        yield {
                            "type": "tool_call_complete",
                            "index": idx,
                            "tool_call": tool_call,
                        }
                if finish_reason:
                    yield {
                        "type": "finish",
                        "reason": finish_reason,
                        "accumulated": self._accumulated_content,
                        "total_tokens": self._total_tokens,
                    }
                    break
        except asyncio.CancelledError:
            logger.info("Tool streaming cancelled")
            raise
        except Exception as e:
            logger.error(f"Tool streaming error: {e}")
            yield {
                "type": "error",
                "error": str(e),
            }

    def get_accumulated_content(self) -> str:
        """Return the content accumulated so far."""
        return self._accumulated_content

    def get_total_tokens(self) -> int:
        """Return the running (approximate) token count."""
        return self._total_tokens

    def reset(self):
        """Reset streaming state so the instance can be reused."""
        self._cancelled = False
        self._accumulated_content = ""
        self._total_tokens = 0
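

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of driving
# both streaming paths. The model name and the get_file tool definition are
# hypothetical placeholders; provider credentials are assumed to be available
# via the environment or the api_key argument.
# ---------------------------------------------------------------------------
async def _demo() -> None:
    streamer = TokenStreamer(model="gpt-4o-mini")  # hypothetical model choice

    # Plain token streaming: print tokens as they arrive.
    messages = [{"role": "user", "content": "Summarize this diff in one line."}]
    async for chunk in streamer.stream_completion(messages):
        print(chunk.content, end="", flush=True)
    print(f"\n~{streamer.get_total_tokens()} chunks streamed")

    # Tool-call streaming: arguments arrive in fragments and are assembled
    # by the accumulator until the provider reports finish_reason="tool_calls".
    streamer.reset()
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_file",  # hypothetical tool, for illustration only
                "description": "Fetch a file from the repository",
                "parameters": {
                    "type": "object",
                    "properties": {"path": {"type": "string"}},
                    "required": ["path"],
                },
            },
        }
    ]
    async for event in streamer.stream_with_tools(messages, tools):
        if event["type"] == "tool_call_complete":
            print("tool call:", event["tool_call"])


if __name__ == "__main__":
    asyncio.run(_demo())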