feat(agent): enhance streaming with in-memory event manager and fallback polling
- Implement dual-mode streaming: prioritize in-memory EventManager for running tasks with thinking_token support
- Add fallback to database polling for completed tasks without thinking_token replay capability
- Introduce SSE event formatter utility for consistent event serialization across streaming modes
- Add 10ms micro-delay for thinking_token events to ensure proper TCP packet separation and frontend incremental rendering
- Refactor stream_agent_with_thinking endpoint to support both runtime and historical event streaming
- Update event filtering logic to handle both in-memory and database event sources
- Improve logging with debug markers for thinking_token tracking and stream mode selection
- Optimize polling intervals: 0.3s for running tasks, 2.0s for completed tasks
- Reduce idle timeout from 10 minutes to 1 minute for completed task streams
- Update frontend useAgentStream hook to handle unified event format from dual-mode streaming
- Enhance AgentAudit UI to properly display streamed events from both sources
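Below is a minimal sketch of how a client could consume the dual-mode stream described above. The endpoint URL and the use of query parameters and a bearer header are assumptions for illustration; the parameter names, event names, and the metadata.token payload location follow the diff below.

# Illustrative consumer sketch; the URL and auth scheme are assumed, not taken from this repo.
import asyncio
import json
import httpx

async def consume_agent_stream(task_id: str, api_token: str) -> None:
    url = f"http://localhost:8000/api/agent/tasks/{task_id}/stream"  # assumed path
    params = {"include_thinking": "true", "include_tool_calls": "true", "after_sequence": 0}
    headers = {"Authorization": f"Bearer {api_token}"}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url, params=params, headers=headers) as resp:
            event_type = None
            async for line in resp.aiter_lines():
                if line.startswith("event:"):
                    event_type = line[len("event:"):].strip()
                elif line.startswith("data:"):
                    data = json.loads(line[len("data:"):].strip())
                    if event_type == "thinking_token":
                        # tokens arrive one per SSE frame thanks to the 10ms pacing
                        print(data.get("metadata", {}).get("token", ""), end="", flush=True)
                    elif event_type in {"task_end", "timeout", "error"}:
                        break

if __name__ == "__main__":
    asyncio.run(consume_agent_stream("task-123", "dev-token"))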
@@ -601,25 +601,13 @@ async def stream_agent_with_thinking(

Enhanced event stream (SSE).

Supports:
- Token-level streaming of the LLM thinking process
- Token-level streaming of the LLM thinking process (runtime only)
- Detailed tool-call input/output
- Node execution status
- Finding events

Event types:
- thinking_start: LLM starts thinking
- thinking_token: LLM emits a token
- thinking_end: LLM finishes thinking
- tool_call_start: tool call started
- tool_call_end: tool call finished
- node_start: node started
- node_end: node finished
- finding_new: new finding
- finding_verified: finding verified
- progress: progress update
- task_complete: task completed
- task_error: task error
- heartbeat: heartbeat

The in-memory event queue is preferred (it supports thinking_token);
if the task is not running, fall back to database polling (no thinking_token replay).
"""
task = await db.get(AgentTask, task_id)

if not task:
@@ -629,119 +617,156 @@ async def stream_agent_with_thinking(
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此任务")

# Define the SSE formatting helper
def format_sse_event(event_data: Dict[str, Any]) -> str:
"""Format as an SSE event."""
event_type = event_data.get("event_type") or event_data.get("type")

# Unify the fields
if "type" not in event_data:
event_data["type"] = event_type

return f"event: {event_type}\ndata: {json.dumps(event_data, ensure_ascii=False)}\n\n"
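For reference, a standalone copy of this helper shows the framing it produces; the sample payload values below are invented for illustration.

# Standalone sketch mirroring format_sse_event above; the sample event is made up.
import json
from typing import Any, Dict

def format_sse_event(event_data: Dict[str, Any]) -> str:
    event_type = event_data.get("event_type") or event_data.get("type")
    if "type" not in event_data:
        event_data["type"] = event_type
    return f"event: {event_type}\ndata: {json.dumps(event_data, ensure_ascii=False)}\n\n"

print(format_sse_event({"event_type": "progress", "message": "Scanning", "sequence": 7}), end="")
# event: progress
# data: {"event_type": "progress", "message": "Scanning", "sequence": 7, "type": "progress"}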
async def enhanced_event_generator():
"""Generate the enhanced SSE event stream."""
last_sequence = after_sequence
poll_interval = 0.3  # shorter polling interval to support streaming
heartbeat_interval = 15  # heartbeat interval
max_idle = 600  # close after 10 minutes without events
idle_time = 0
last_heartbeat = 0
# 1. Check whether the task is running (in memory)
runner = _running_tasks.get(task_id)

# Event type filtering
skip_types = set()
if not include_thinking:
skip_types.update(["thinking_start", "thinking_token", "thinking_end"])
if not include_tool_calls:
skip_types.update(["tool_call_start", "tool_call_input", "tool_call_output", "tool_call_end"])

while True:
if runner:
logger.info(f"Stream {task_id}: Using in-memory event manager")
try:
async with async_session_factory() as session:
# Query new events
result = await session.execute(
select(AgentEvent)
.where(AgentEvent.task_id == task_id)
.where(AgentEvent.sequence > last_sequence)
.order_by(AgentEvent.sequence)
.limit(100)
)
events = result.scalars().all()
# Use the EventManager streaming interface
# Filter options
skip_types = set()
if not include_thinking:
skip_types.update(["thinking_start", "thinking_token", "thinking_end"])
if not include_tool_calls:
skip_types.update(["tool_call_start", "tool_call_input", "tool_call_output", "tool_call_end"])

async for event in runner.event_manager.stream_events(task_id, after_sequence=after_sequence):
event_type = event.get("event_type")

# Get the task status
current_task = await session.get(AgentTask, task_id)
task_status = current_task.status if current_task else None

if events:
idle_time = 0
for event in events:
last_sequence = event.sequence
if event_type in skip_types:
continue

# 🔥 Debug: log thinking_token events
if event_type == "thinking_token":
token = event.get("metadata", {}).get("token", "")[:20]
logger.debug(f"Stream {task_id}: Sending thinking_token: '{token}...'")

# Get the event type string (event_type is already a string)
event_type = str(event.event_type)

# Filter events
if event_type in skip_types:
continue

# Build the event data
data = {
"id": event.id,
"type": event_type,
"phase": str(event.phase) if event.phase else None,
"message": event.message,
"sequence": event.sequence,
"timestamp": event.created_at.isoformat() if event.created_at else None,
}

# Add tool call details
if include_tool_calls and event.tool_name:
data["tool"] = {
"name": event.tool_name,
"input": event.tool_input,
"output": event.tool_output,
"duration_ms": event.tool_duration_ms,
}

# Add metadata
if event.event_metadata:
data["metadata"] = event.event_metadata

# Add token usage
if event.tokens_used:
data["tokens_used"] = event.tokens_used

# Use the standard SSE format
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
else:
idle_time += poll_interval

# Check whether the task has finished
if task_status:
status_str = str(task_status)
if status_str in ["completed", "failed", "cancelled"]:
end_data = {
"type": "task_end",
"status": status_str,
"message": f"任务{'完成' if status_str == 'completed' else '结束'}",
}
yield f"event: task_end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
break

# Send heartbeat
last_heartbeat += poll_interval
if last_heartbeat >= heartbeat_interval:
last_heartbeat = 0
heartbeat_data = {
"type": "heartbeat",
"timestamp": datetime.now(timezone.utc).isoformat(),
"last_sequence": last_sequence,
}
yield f"event: heartbeat\ndata: {json.dumps(heartbeat_data)}\n\n"

# Check idle timeout
if idle_time >= max_idle:
timeout_data = {"type": "timeout", "message": "连接超时"}
yield f"event: timeout\ndata: {json.dumps(timeout_data)}\n\n"
break

await asyncio.sleep(poll_interval)

# Format and yield
yield format_sse_event(event)

# 🔥 CRITICAL: add a tiny delay for thinking_token events
# so they go out in separate TCP packets and the frontend can process them one by one.
# Without this delay all tokens arrive in a single read(), causing React to batch the updates.
if event_type == "thinking_token":
await asyncio.sleep(0.01)  # 10ms delay

except Exception as e:
logger.error(f"Stream error: {e}")
error_data = {"type": "error", "message": str(e)}
yield f"event: error\ndata: {json.dumps(error_data)}\n\n"
break
logger.error(f"In-memory stream error: {e}")
err_data = {"type": "error", "message": str(e)}
yield format_sse_event(err_data)
else:
logger.info(f"Stream {task_id}: Task not running, falling back to DB polling")
# 2. Fall back to database polling (thinking_token is not available)
last_sequence = after_sequence
poll_interval = 2.0  # completed tasks can be polled more slowly
heartbeat_interval = 15
max_idle = 60  # close after 1 minute without events
idle_time = 0
last_heartbeat = 0

skip_types = set()
if not include_thinking:
skip_types.update(["thinking_start", "thinking_token", "thinking_end"])

while True:
try:
async with async_session_factory() as session:
# Query new events
result = await session.execute(
select(AgentEvent)
.where(AgentEvent.task_id == task_id)
.where(AgentEvent.sequence > last_sequence)
.order_by(AgentEvent.sequence)
.limit(100)
)
events = result.scalars().all()

# Get the task status
current_task = await session.get(AgentTask, task_id)
task_status = current_task.status if current_task else None

if events:
idle_time = 0
for event in events:
last_sequence = event.sequence
event_type = str(event.event_type)

if event_type in skip_types:
continue

# Build the data
data = {
"id": event.id,
"type": event_type,
"phase": str(event.phase) if event.phase else None,
"message": event.message,
"sequence": event.sequence,
"timestamp": event.created_at.isoformat() if event.created_at else None,
}

# Add details
if include_tool_calls and event.tool_name:
data["tool"] = {
"name": event.tool_name,
"input": event.tool_input,
"output": event.tool_output,
"duration_ms": event.tool_duration_ms,
}

if event.event_metadata:
data["metadata"] = event.event_metadata

if event.tokens_used:
data["tokens_used"] = event.tokens_used

yield format_sse_event(data)
else:
idle_time += poll_interval

# Check whether the stream should end
if task_status:
status_str = str(task_status)
# If the task has finished and there are no new events, end the stream
if status_str in ["completed", "failed", "cancelled"]:
end_data = {
"type": "task_end",
"status": status_str,
"message": f"任务已{status_str}"
}
yield format_sse_event(end_data)
break

# Heartbeat
last_heartbeat += poll_interval
if last_heartbeat >= heartbeat_interval:
last_heartbeat = 0
yield format_sse_event({"type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()})

# Timeout
if idle_time >= max_idle:
break

await asyncio.sleep(poll_interval)

except Exception as e:
logger.error(f"DB poll stream error: {e}")
yield format_sse_event({"type": "error", "message": str(e)})
break

return StreamingResponse(
enhanced_event_generator(),
@@ -187,6 +187,10 @@ class AnalysisAgent(BaseAgent):
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
if thought_match:
step.thought = thought_match.group(1).strip()
elif not re.search(r'Action:|Final Answer:', response):
# 🔥 Fallback: If no markers found, treat the whole response as Thought
if response.strip():
step.thought = response.strip()

# Check whether this is the final answer
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)

@@ -330,7 +334,17 @@ class AnalysisAgent(BaseAgent):
break

self._total_tokens += tokens_this_round

# 🔥 Handle empty LLM response to prevent loops
if not llm_output or not llm_output.strip():
logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}")
await self.emit_llm_decision("收到空响应", "LLM 返回内容为空,尝试重试通过提示")
self._conversation_history.append({
"role": "user",
"content": "Received empty response. Please output your Thought and Action.",
})
continue

# Parse the LLM response
step = self._parse_llm_response(llm_output)
self._steps.append(step)

@@ -406,6 +420,11 @@ class AnalysisAgent(BaseAgent):
# Standardize findings
standardized_findings = []
for finding in all_findings:
# Ensure the finding is a dict
if not isinstance(finding, dict):
logger.warning(f"Skipping invalid finding (not a dict): {finding}")
continue

standardized = {
"vulnerability_type": finding.get("vulnerability_type", "other"),
"severity": finding.get("severity", "medium"),
@@ -409,17 +409,38 @@ class BaseAgent(ABC):
"""Emit an event."""
if self.event_emitter:
from ..event_manager import AgentEventData

# Prepare metadata
metadata = kwargs.get("metadata", {}) or {}
if "agent_name" not in metadata:
metadata["agent_name"] = self.name

# Separate known fields from unknown fields
known_fields = {
"phase", "tool_name", "tool_input", "tool_output",
"tool_duration_ms", "finding_id", "tokens_used"
}

event_kwargs = {}
for k, v in kwargs.items():
if k in known_fields:
event_kwargs[k] = v
elif k != "metadata":
# Put unknown fields into metadata
metadata[k] = v

await self.event_emitter.emit(AgentEventData(
event_type=event_type,
message=message,
**kwargs
metadata=metadata,
**event_kwargs
))
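To make the field separation concrete, here is a small self-contained sketch of the same splitting rule; the call shape and the file_path keyword are illustrative assumptions, not code from this diff.

# Standalone sketch of the splitting rule above (not the repo's code); `file_path`
# is a hypothetical extra keyword used only to show how unknown fields travel.
KNOWN_FIELDS = {
    "phase", "tool_name", "tool_input", "tool_output",
    "tool_duration_ms", "finding_id", "tokens_used",
}

def split_event_kwargs(agent_name: str, **kwargs):
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata.setdefault("agent_name", agent_name)
    event_kwargs = {}
    for key, value in kwargs.items():
        if key in KNOWN_FIELDS:
            event_kwargs[key] = value  # passed through as first-class event fields
        else:
            metadata[key] = value      # unknown fields ride along in metadata
    return event_kwargs, metadata

print(split_event_kwargs("ReconAgent", phase="recon", file_path="app/db.py", tokens_used=120))
# ({'phase': 'recon', 'tokens_used': 120}, {'agent_name': 'ReconAgent', 'file_path': 'app/db.py'})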
# ============ LLM thinking events ============

async def emit_thinking(self, message: str):
"""Emit an LLM thinking event."""
await self.emit_event("thinking", f"[{self.name}] {message}")
await self.emit_event("thinking", message)

async def emit_llm_start(self, iteration: int):
"""Emit an LLM start-of-thinking event."""

@@ -444,7 +465,7 @@ class BaseAgent(ABC):

async def emit_thinking_start(self):
"""Emit a thinking-start event (for streaming output)."""
await self.emit_event("thinking_start", f"[{self.name}] 开始思考...")
await self.emit_event("thinking_start", "开始思考...")

async def emit_thinking_token(self, token: str, accumulated: str):
"""Emit a thinking-token event (for streaming output)."""

@@ -461,7 +482,7 @@ class BaseAgent(ABC):
"""Emit a thinking-end event (for streaming output)."""
await self.emit_event(
"thinking_end",
f"[{self.name}] 思考完成",
"思考完成",
metadata={"accumulated": full_response}
)
@@ -690,6 +711,9 @@ class BaseAgent(ABC):
token = chunk["content"]
accumulated = chunk["accumulated"]
await self.emit_thinking_token(token, accumulated)
# 🔥 CRITICAL: yield control to the event loop so SSE gets a chance to send the event.
# Without this, all tokens would be sent together after the loop finishes.
await asyncio.sleep(0)

elif chunk["type"] == "done":
accumulated = chunk["content"]
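The effect of that await asyncio.sleep(0) can be reproduced with a small self-contained demo (not repo code): without yielding, a tight producer loop finishes before the consumer coroutine ever runs, so every token is delivered in one burst.

# Standalone demo of why yielding per token matters; the names here are illustrative.
import asyncio

async def producer(queue: asyncio.Queue, yield_per_token: bool) -> None:
    for i in range(3):
        print(f"produce token-{i}")
        queue.put_nowait(f"token-{i}")
        if yield_per_token:
            await asyncio.sleep(0)  # let the consumer run after each token
    queue.put_nowait(None)  # sentinel: no more tokens

async def consumer(queue: asyncio.Queue) -> None:
    while (item := await queue.get()) is not None:
        print(f"consume {item}")

async def main() -> None:
    for yield_per_token in (False, True):
        print(f"--- yield_per_token={yield_per_token} ---")
        queue: asyncio.Queue = asyncio.Queue()
        await asyncio.gather(producer(queue, yield_per_token), consumer(queue))

asyncio.run(main())
# Without yielding, all "produce" lines print before any "consume" line;
# with sleep(0), produce/consume lines interleave token by token.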
@@ -421,6 +421,9 @@ class OrchestratorAgent(BaseAgent):
### 发现摘要
"""
for i, f in enumerate(findings[:10]):  # show at most 10
if not isinstance(f, dict):
continue

observation += f"""
{i+1}. [{f.get('severity', 'unknown')}] {f.get('title', 'Unknown')}
- 类型: {f.get('vulnerability_type', 'unknown')}

@@ -452,6 +455,9 @@ class OrchestratorAgent(BaseAgent):
type_counts = {}

for f in self._all_findings:
if not isinstance(f, dict):
continue

sev = f.get("severity", "low")
severity_counts[sev] = severity_counts.get(sev, 0) + 1

@@ -475,7 +481,8 @@ class OrchestratorAgent(BaseAgent):

summary += "\n### 详细列表\n"
for i, f in enumerate(self._all_findings):
summary += f"{i+1}. [{f.get('severity')}] {f.get('title')} ({f.get('file_path')})\n"
if isinstance(f, dict):
summary += f"{i+1}. [{f.get('severity')}] {f.get('title')} ({f.get('file_path')})\n"

return summary

@@ -484,8 +491,9 @@ class OrchestratorAgent(BaseAgent):
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}

for f in self._all_findings:
sev = f.get("severity", "low")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
if isinstance(f, dict):
sev = f.get("severity", "low")
severity_counts[sev] = severity_counts.get(sev, 0) + 1

return {
"total_findings": len(self._all_findings),
@@ -157,6 +157,11 @@ class ReconAgent(BaseAgent):
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
if thought_match:
step.thought = thought_match.group(1).strip()
elif not re.search(r'Action:|Final Answer:', response):
# 🔥 Fallback: If no markers found, treat the whole response as Thought
# This prevents empty-step loops ("Decision: Continue Thinking")
if response.strip():
step.thought = response.strip()

# Check whether this is the final answer
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)

@@ -170,6 +175,12 @@ class ReconAgent(BaseAgent):
answer_text,
default={"raw_answer": answer_text}
)
# Ensure the findings are well-formed
if "initial_findings" in step.final_answer:
step.final_answer["initial_findings"] = [
f for f in step.final_answer["initial_findings"]
if isinstance(f, dict)
]
return step

# Extract the Action

@@ -256,6 +267,16 @@ class ReconAgent(BaseAgent):

self._total_tokens += tokens_this_round

# 🔥 Handle empty LLM response to prevent loops
if not llm_output or not llm_output.strip():
logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}")
await self.emit_llm_decision("收到空响应", "LLM 返回内容为空,尝试重试通过提示")
self._conversation_history.append({
"role": "user",
"content": "Received empty response. Please output your Thought and Action.",
})
continue

# Parse the LLM response
step = self._parse_llm_response(llm_output)
self._steps.append(step)
@@ -169,6 +169,10 @@ class VerificationAgent(BaseAgent):
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
if thought_match:
step.thought = thought_match.group(1).strip()
elif not re.search(r'Action:|Final Answer:', response):
# 🔥 Fallback: If no markers found, treat the whole response as Thought
if response.strip():
step.thought = response.strip()

# Check whether this is the final answer
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)

@@ -337,7 +341,17 @@ class VerificationAgent(BaseAgent):
break

self._total_tokens += tokens_this_round

# 🔥 Handle empty LLM response to prevent loops
if not llm_output or not llm_output.strip():
logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}")
await self.emit_llm_decision("收到空响应", "LLM 返回内容为空,尝试重试通过提示")
self._conversation_history.append({
"role": "user",
"content": "Received empty response. Please output your Thought and Action.",
})
continue

# Parse the LLM response
step = self._parse_llm_response(llm_output)
self._steps.append(step)
@@ -300,16 +300,19 @@ class EventManager:
}

# Save to the database (skip high-frequency events such as thinking_token)
skip_db_events = {"thinking_token", "thinking_start", "thinking_end"}
skip_db_events = {"thinking_token"}
if self.db_session_factory and event_type not in skip_db_events:
try:
await self._save_event_to_db(event_data)
except Exception as e:
logger.error(f"Failed to save event to database: {e}")

# Push to the queue
# Push to the queue (non-blocking)
if task_id in self._event_queues:
await self._event_queues[task_id].put(event_data)
try:
self._event_queues[task_id].put_nowait(event_data)
except asyncio.QueueFull:
logger.warning(f"Event queue full for task {task_id}, dropping event: {event_type}")

# Invoke callbacks
if task_id in self._event_callbacks:

@@ -348,9 +351,10 @@ class EventManager:
await db.commit()

def create_queue(self, task_id: str) -> asyncio.Queue:
"""Create an event queue."""
"""Create or get an event queue."""
if task_id not in self._event_queues:
self._event_queues[task_id] = asyncio.Queue()
# 🔥 Use a larger queue capacity to buffer more token events
self._event_queues[task_id] = asyncio.Queue(maxsize=1000)
return self._event_queues[task_id]

def remove_queue(self, task_id: str):
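A tiny self-contained demo of the bounded, drop-on-full behaviour these two changes add (capacity shrunk to 3 so the drop is visible; not repo code):

import asyncio

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue(maxsize=3)  # stands in for the 1000-slot queue
    dropped = 0
    for i in range(5):
        try:
            queue.put_nowait(f"event-{i}")   # non-blocking, like the emit path above
        except asyncio.QueueFull:
            dropped += 1                     # producer keeps going instead of blocking
    print(f"queued={queue.qsize()} dropped={dropped}")  # queued=3 dropped=2

asyncio.run(main())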
@@ -398,13 +402,43 @@ class EventManager:
task_id: str,
after_sequence: int = 0,
) -> AsyncGenerator[Dict, None]:
"""Stream events."""
queue = self.create_queue(task_id)
"""Stream events.

# First send historical events
history = await self.get_events(task_id, after_sequence)
for event in history:
yield event
🔥 Important: this method first drains events already buffered in the queue
(produced before the SSE connection was established), then keeps pushing new events in real time.
"""
# Get the existing queue (created by AgentRunner during initialization)
queue = self._event_queues.get(task_id)

if not queue:
# If the queue does not exist, create a new one (fallback logic)
queue = self.create_queue(task_id)
logger.warning(f"Queue not found for task {task_id}, created new one")

# 🔥 First drain events already buffered in the queue (produced before the SSE connection)
buffered_count = 0
while not queue.empty():
try:
buffered_event = queue.get_nowait()
buffered_count += 1
yield buffered_event

# 🔥 Add a delay for every buffered event so they are not flushed out all at once
event_type = buffered_event.get("event_type")
if event_type == "thinking_token":
await asyncio.sleep(0.015)  # 15ms for tokens
else:
await asyncio.sleep(0.005)  # 5ms for other events

# Check for a terminal event
if event_type in ["task_complete", "task_error", "task_cancel"]:
logger.info(f"Task {task_id} already completed, sent {buffered_count} buffered events")
return
except asyncio.QueueEmpty:
break

if buffered_count > 0:
logger.info(f"Drained {buffered_count} buffered events for task {task_id}")

# Then push new events in real time
try:

@@ -413,6 +447,10 @@ class EventManager:
event = await asyncio.wait_for(queue.get(), timeout=30)
yield event

# 🔥 Add a micro-delay for thinking_token to preserve the streaming effect
if event.get("event_type") == "thinking_token":
await asyncio.sleep(0.01)  # 10ms

# Check for a terminal event
if event.get("event_type") in ["task_complete", "task_error", "task_cancel"]:
break

@@ -421,8 +459,10 @@ class EventManager:
# Send a heartbeat
yield {"event_type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()}

finally:
self.remove_queue(task_id)
except GeneratorExit:
# SSE connection closed
logger.debug(f"SSE stream closed for task {task_id}")
# 🔥 Do not remove the queue; AgentRunner manages the queue's lifecycle

def create_emitter(self, task_id: str) -> AgentEventEmitter:
"""Create an event emitter."""
@@ -72,6 +72,10 @@ class AgentRunner:
self.event_manager = EventManager(db_session_factory=async_session_factory)
self.event_emitter = AgentEventEmitter(task.id, self.event_manager)

# 🔥 CRITICAL: create the event queue immediately so it exists before the Agent starts running.
# That way token events are not lost even if the frontend SSE connection arrives a bit later.
self.event_manager.create_queue(task.id)

# 🔥 LLM service - use the user's configuration (from the system settings page)
self.llm_service = LLMService(user_config=self.user_config)

@@ -708,6 +712,11 @@ class AgentRunner:

for finding in findings:
try:
# Ensure the finding is a dict
if not isinstance(finding, dict):
logger.warning(f"Skipping invalid finding (not a dict): {finding}")
continue

db_finding = AgentFinding(
id=str(uuid.uuid4()),
task_id=self.task.id,

Binary file not shown.
@@ -12,7 +12,7 @@ import {
AgentStreamState,
} from '../shared/api/agentStream';

export interface UseAgentStreamOptions extends Omit<StreamOptions, 'onEvent'> {
export interface UseAgentStreamOptions extends StreamOptions {
autoConnect?: boolean;
maxEvents?: number;
}

@@ -73,6 +73,10 @@ export function useAgentStream(
...callbackOptions
} = options;

// 🔥 Store the callback options in a ref so changing them does not change connect's dependencies and force a reconnect
const callbackOptionsRef = useRef(callbackOptions);
callbackOptionsRef.current = callbackOptions;

// State
const [events, setEvents] = useState<StreamEventData[]>([]);
const [thinking, setThinking] = useState('');

@@ -115,8 +119,17 @@ export function useAgentStream(
includeThinking,
includeToolCalls,
afterSequence,

onEvent: (event) => {
// Pass to custom callback first (important for capturing metadata like agent_name)
callbackOptionsRef.current.onEvent?.(event);

// Ignore thinking events so they do not pollute the log list (they are handled separately via the onThinking* callbacks)
if (
event.type === 'thinking_token' ||
event.type === 'thinking_start' ||
event.type === 'thinking_end'
) return;
setEvents((prev) => [...prev.slice(-maxEvents + 1), event]);
},
@@ -124,20 +137,20 @@ export function useAgentStream(
thinkingBufferRef.current = [];
setIsThinking(true);
setThinking('');
callbackOptions.onThinkingStart?.();
callbackOptionsRef.current.onThinkingStart?.();
},

onThinkingToken: (token, accumulated) => {
thinkingBufferRef.current.push(token);
setThinking(accumulated);
callbackOptions.onThinkingToken?.(token, accumulated);
callbackOptionsRef.current.onThinkingToken?.(token, accumulated);
},

onThinkingEnd: (response) => {
setIsThinking(false);
setThinking(response);
thinkingBufferRef.current = [];
callbackOptions.onThinkingEnd?.(response);
callbackOptionsRef.current.onThinkingEnd?.(response);
},

onToolStart: (name, input) => {

@@ -145,7 +158,7 @@ export function useAgentStream(
...prev,
{ name, input, status: 'running' as const },
]);
callbackOptions.onToolStart?.(name, input);
callbackOptionsRef.current.onToolStart?.(name, input);
},

onToolEnd: (name, output, durationMs) => {

@@ -156,16 +169,16 @@ export function useAgentStream(
: tc
)
);
callbackOptions.onToolEnd?.(name, output, durationMs);
callbackOptionsRef.current.onToolEnd?.(name, output, durationMs);
},

onNodeStart: (nodeName, phase) => {
setCurrentPhase(phase);
callbackOptions.onNodeStart?.(nodeName, phase);
callbackOptionsRef.current.onNodeStart?.(nodeName, phase);
},

onNodeEnd: (nodeName, summary) => {
callbackOptions.onNodeEnd?.(nodeName, summary);
callbackOptionsRef.current.onNodeEnd?.(nodeName, summary);
},

onProgress: (current, total, message) => {

@@ -174,35 +187,35 @@ export function useAgentStream(
total,
percentage: total > 0 ? Math.round((current / total) * 100) : 0,
});
callbackOptions.onProgress?.(current, total, message);
callbackOptionsRef.current.onProgress?.(current, total, message);
},

onFinding: (finding, isVerified) => {
setFindings((prev) => [...prev, finding]);
callbackOptions.onFinding?.(finding, isVerified);
callbackOptionsRef.current.onFinding?.(finding, isVerified);
},

onComplete: (data) => {
setIsComplete(true);
setIsConnected(false);
callbackOptions.onComplete?.(data);
callbackOptionsRef.current.onComplete?.(data);
},

onError: (err) => {
setError(err);
setIsComplete(true);
setIsConnected(false);
callbackOptions.onError?.(err);
callbackOptionsRef.current.onError?.(err);
},

onHeartbeat: () => {
callbackOptions.onHeartbeat?.();
callbackOptionsRef.current.onHeartbeat?.();
},
});

handlerRef.current.connect();
setIsConnected(true);
}, [taskId, includeThinking, includeToolCalls, afterSequence, maxEvents, callbackOptions]);
}, [taskId, includeThinking, includeToolCalls, afterSequence, maxEvents]); // 🔥 callbackOptions removed from the dependency array

// Disconnect
const disconnect = useCallback(() => {
@@ -261,7 +274,7 @@ export function useAgentThinking(taskId: string | null) {
const { thinking, isThinking, connect, disconnect } = useAgentStream(taskId, {
includeToolCalls: false,
});

return { thinking, isThinking, connect, disconnect };
}

@@ -272,9 +285,8 @@ export function useAgentToolCalls(taskId: string | null) {
const { toolCalls, connect, disconnect } = useAgentStream(taskId, {
includeThinking: false,
});

return { toolCalls, connect, disconnect };
}

export default useAgentStream;
File diff suppressed because it is too large
@@ -10,6 +10,7 @@
// Event type definitions
export type StreamEventType =
// LLM related
| 'thinking' // General thinking event
| 'thinking_start'
| 'thinking_token'
| 'thinking_end'

@@ -19,6 +20,8 @@ export type StreamEventType =
| 'tool_call_output'
| 'tool_call_end'
| 'tool_call_error'
| 'tool_call' // Backend sends this
| 'tool_result' // Backend sends this
// Node related
| 'node_start'
| 'node_end'

@@ -69,6 +72,12 @@ export interface StreamEventData {
error?: string; // task_error
findings_count?: number; // task_complete
security_score?: number; // task_complete
// Backend tool event fields
tool_name?: string; // tool_call, tool_result
tool_input?: Record<string, unknown>; // tool_call
tool_output?: unknown; // tool_result
tool_duration_ms?: number; // tool_result
agent_name?: string; // Extracted from metadata
}

// Event callback types
@@ -191,19 +200,24 @@ export class AgentStreamHandler {
}

const { done, value } = await this.reader.read();

if (done) {
break;
}

buffer += decoder.decode(value, { stream: true });

// Parse SSE events
const events = this.parseSSE(buffer);
buffer = events.remaining;

// 🔥 Handle events one by one, adding a micro-delay so React can render them individually
for (const event of events.parsed) {
this.handleEvent(event);
// Add a micro-delay for thinking_token to keep the typing effect
if (event.type === 'thinking_token') {
await new Promise(resolve => setTimeout(resolve, 5));
}
}
}

@@ -220,7 +234,7 @@ export class AgentStreamHandler {

this.isConnected = false;
console.error('Stream connection error:', error);

// 🔥 Only try to reconnect if we are not deliberately disconnecting
if (!this.isDisconnecting && this.reconnectAttempts < this.maxReconnectAttempts) {
this.reconnectAttempts++;

@@ -253,10 +267,10 @@ export class AgentStreamHandler {
const lines = buffer.split('\n');
let remaining = '';
let currentEvent: Partial<StreamEventData> = {};

for (let i = 0; i < lines.length; i++) {
const line = lines[i];

// An empty line marks the end of an event
if (line === '') {
if (currentEvent.type) {

@@ -265,13 +279,13 @@ export class AgentStreamHandler {
}
continue;
}

// Check whether this is the last (possibly incomplete) line
if (i === lines.length - 1 && !buffer.endsWith('\n')) {
remaining = line;
break;
}

// Parse "event:" lines
if (line.startsWith('event:')) {
currentEvent.type = line.slice(6).trim() as StreamEventType;

@@ -286,7 +300,7 @@ export class AgentStreamHandler {
}
}
}

return { parsed, remaining };
}
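For reference, the same framing rules can be exercised with a short standalone Python sketch mirroring the parse logic above (illustrative, not taken from the repo): event:/data: lines accumulate into one event, a blank line closes it, and a trailing partial line is kept as the remainder for the next read.

import json

def parse_sse(buffer: str):
    parsed, remaining, current = [], "", {}
    lines = buffer.split("\n")
    for i, line in enumerate(lines):
        if i == len(lines) - 1 and not buffer.endswith("\n"):
            remaining = line  # incomplete tail, keep for the next chunk
            break
        if line == "":
            if current:
                parsed.append(current)
                current = {}
        elif line.startswith("event:"):
            current["type"] = line[len("event:"):].strip()
        elif line.startswith("data:"):
            current.update(json.loads(line[len("data:"):].strip()))
    return parsed, remaining

chunk = 'event: heartbeat\ndata: {"type": "heartbeat"}\n\nevent: thinking_tok'
print(parse_sse(chunk))  # ([{'type': 'heartbeat'}], 'event: thinking_tok')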
@@ -294,6 +308,11 @@ export class AgentStreamHandler {
* Handle an event
*/
private handleEvent(event: StreamEventData): void {
// Extract agent_name from metadata if present
if (event.metadata?.agent_name && !event.agent_name) {
event.agent_name = event.metadata.agent_name as string;
}

// Generic callback
this.options.onEvent?.(event);

@@ -304,19 +323,23 @@ export class AgentStreamHandler {
this.thinkingBuffer = [];
this.options.onThinkingStart?.();
break;

case 'thinking_token':
if (event.token) {
this.thinkingBuffer.push(event.token);
// Compatibility: the token may be at the top level or inside metadata
const token = event.token || (event.metadata?.token as string);
const accumulated = event.accumulated || (event.metadata?.accumulated as string);

if (token) {
this.thinkingBuffer.push(token);
this.options.onThinkingToken?.(
event.token,
event.accumulated || this.thinkingBuffer.join('')
token,
accumulated || this.thinkingBuffer.join('')
);
}
break;

case 'thinking_end':
const fullResponse = event.accumulated || this.thinkingBuffer.join('');
const fullResponse = event.accumulated || (event.metadata?.accumulated as string) || this.thinkingBuffer.join('');
this.thinkingBuffer = [];
this.options.onThinkingEnd?.(fullResponse);
break;

@@ -330,7 +353,7 @@ export class AgentStreamHandler {
);
}
break;

case 'tool_call_end':
if (event.tool) {
this.options.onToolEnd?.(

@@ -341,6 +364,22 @@ export class AgentStreamHandler {
}
break;

// Alternative event names (backend sends these)
case 'tool_call':
this.options.onToolStart?.(
event.tool_name || 'unknown',
event.tool_input || {}
);
break;

case 'tool_result':
this.options.onToolEnd?.(
event.tool_name || 'unknown',
event.tool_output,
event.tool_duration_ms || 0
);
break;

// Node events
case 'node_start':
this.options.onNodeStart?.(

@@ -348,7 +387,7 @@ export class AgentStreamHandler {
event.phase || ''
);
break;

case 'node_end':
this.options.onNodeEnd?.(
event.metadata?.node as string || 'unknown',
@@ -407,13 +446,13 @@ export class AgentStreamHandler {
// 🔥 Mark that we are disconnecting to prevent reconnection
this.isDisconnecting = true;
this.isConnected = false;

// 🔥 Abort the fetch request
if (this.abortController) {
this.abortController.abort();
this.abortController = null;
}

// 🔥 Clean up the reader
if (this.reader) {
try {

@@ -424,13 +463,13 @@ export class AgentStreamHandler {
}
this.reader = null;
}

// Clean up the EventSource (if used)
if (this.eventSource) {
this.eventSource.close();
this.eventSource = null;
}

// Reset the reconnect counter
this.reconnectAttempts = 0;
}

@@ -469,6 +508,7 @@ export interface AgentStreamState {
events: StreamEventData[];
thinking: string;
isThinking: boolean;
thinkingAgent?: string; // Who is thinking
toolCalls: Array<{
name: string;
input: Record<string, unknown>;
@@ -494,6 +534,7 @@ export function createAgentStreamWithState(
events: [],
thinking: '',
isThinking: false,
thinkingAgent: undefined,
toolCalls: [],
currentPhase: '',
progress: { current: 0, total: 100, percentage: 0 },

@@ -509,9 +550,16 @@ export function createAgentStreamWithState(

return new AgentStreamHandler(taskId, {
onEvent: (event) => {
updateState({
events: [...state.events, event].slice(-500), // keep the most recent 500
});
const updates: Partial<AgentStreamState> = {
events: [...state.events, event].slice(-500),
};

// Update thinking agent if available
if (event.agent_name && (event.type === 'thinking' || event.type === 'thinking_start' || event.type === 'thinking_token')) {
updates.thinkingAgent = event.agent_name;
}

updateState(updates);
},
onThinkingStart: () => {
updateState({ isThinking: true, thinking: '' });
File diff suppressed because one or more lines are too long