CodeReview/backend/app/services/agent/core/context.py

358 lines
11 KiB
Python

"""
Execution Context Module
Provides distributed tracing and correlation ID management for the Agent framework.
Enables tracking of requests across agents, tools, and services.
"""
import contextvars
import functools
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from uuid import uuid4
# ============ Context Variables ============
# Module-level ContextVars holding the ambient tracing state. Each one has a
# default, so ``.get()`` never raises LookupError.

# Correlation ID; the empty string means "not set yet".
_correlation_id: contextvars.ContextVar[str] = contextvars.ContextVar("correlation_id", default="")

# Task ID; the empty string is the default.
_task_id: contextvars.ContextVar[str] = contextvars.ContextVar("task_id", default="")

# Name of the agent currently executing; None is the default.
_current_agent: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("current_agent", default=None)

# Ordered list of agent names forming the trace path.
# NOTE: the default is a single shared list object. The accessors in this
# module always copy or rebuild the list and never mutate it in place;
# any new code touching this variable must do the same.
_trace_path: contextvars.ContextVar[List[str]] = contextvars.ContextVar("trace_path", default=[])
# ============ Context Accessors ============
def get_correlation_id() -> str:
    """
    Return the correlation ID bound to the current context.

    If the stored value is empty, a new ID is obtained from
    ``generate_correlation_id()``, stored in the context variable, and
    returned, so later calls in the same context see the same ID.

    Returns:
        The existing or newly generated correlation ID
    """
    existing = _correlation_id.get()
    if existing:
        return existing
    # Nothing bound yet: create an ID and remember it for later lookups.
    fresh = generate_correlation_id()
    _correlation_id.set(fresh)
    return fresh
def set_correlation_id(cid: str) -> contextvars.Token:
    """
    Bind a correlation ID in the current context.

    Args:
        cid: The correlation ID to store

    Returns:
        The token produced by ``ContextVar.set``; pass it to the variable's
        ``reset`` to restore the previous value
    """
    token = _correlation_id.set(cid)
    return token
def get_task_id() -> str:
    """
    Return the task ID bound to the current context.

    Returns:
        The stored task ID (the variable's default is the empty string)
    """
    task_id = _task_id.get()
    return task_id
def set_task_id(task_id: str) -> contextvars.Token:
    """
    Bind a task ID in the current context.

    Args:
        task_id: The task ID to store

    Returns:
        The token produced by ``ContextVar.set``, usable to restore the
        previous value
    """
    token = _task_id.set(task_id)
    return token
def get_current_agent() -> Optional[str]:
    """
    Return the agent name bound to the current context.

    Returns:
        The stored agent name, or None (the variable's default) when no
        agent has been set
    """
    agent_name = _current_agent.get()
    return agent_name
def set_current_agent(agent_name: str) -> contextvars.Token:
    """
    Bind the current agent name in the current context.

    Args:
        agent_name: Name of the agent to record as current

    Returns:
        The token produced by ``ContextVar.set``, usable to restore the
        previous value
    """
    token = _current_agent.set(agent_name)
    return token
def get_trace_path() -> List[str]:
    """
    Return a copy of the current trace path.

    Returns:
        A new list of agent names; mutating it does not affect the value
        stored in the context variable
    """
    stored_path = _trace_path.get()
    return stored_path.copy()
def push_trace(agent_name: str) -> None:
    """
    Append an agent name to the trace path.

    A new list is built and stored; the previously stored list is left
    unmodified.

    Args:
        agent_name: Name to add at the end of the trace path
    """
    extended = list(_trace_path.get())
    extended.append(agent_name)
    _trace_path.set(extended)
def pop_trace() -> Optional[str]:
    """
    Drop the last entry of the trace path.

    A shortened list is stored in the context variable; the previously
    stored list is left unmodified.

    Returns:
        The removed agent name, or None when the trace path is empty
    """
    path = _trace_path.get()
    if not path:
        return None
    *remaining, last = path
    _trace_path.set(remaining)
    return last
def generate_correlation_id() -> str:
    """
    Build a new correlation ID.

    Returns:
        The string ``cid-`` followed by the first 12 hex characters of a
        random UUID4
    """
    suffix = uuid4().hex[:12]
    return f"cid-{suffix}"
# ============ Execution Context ============
def _utc_timestamp() -> str:
    """
    Return the current UTC time as a naive ISO-8601 string.

    Yields the same format as ``datetime.utcnow().isoformat()`` (no UTC
    offset suffix) without calling ``utcnow()``, which is deprecated since
    Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


@dataclass
class ExecutionContext:
    """
    Execution context for tracking requests across the agent system.

    This context is passed down through agent calls and tool executions
    to enable distributed tracing and debugging.

    The ``child_context`` / ``with_*`` methods return new instances and copy
    ``trace_path`` and ``metadata`` (shallowly), so the source context is
    not modified by them.
    """

    # Identifier shared by every context derived from the same request.
    correlation_id: str = field(default_factory=generate_correlation_id)
    task_id: str = ""
    # ID of the agent that created this context via child_context(), if any.
    parent_agent_id: Optional[str] = None
    current_agent_id: Optional[str] = None
    current_agent_name: Optional[str] = None
    # Agent names from the root down to the current agent.
    trace_path: List[str] = field(default_factory=list)
    iteration: int = 0
    # Nesting level; child_context() increments it by one.
    depth: int = 0
    # Naive UTC timestamp in ISO-8601 format.
    created_at: str = field(default_factory=_utc_timestamp)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def child_context(
        self,
        agent_id: str,
        agent_name: str,
    ) -> "ExecutionContext":
        """
        Create a child context for a sub-agent.

        The child keeps the correlation and task IDs, records this context's
        agent as its parent, extends the trace path with ``agent_name``,
        restarts ``iteration`` at 0, and gets a fresh ``created_at``.

        Args:
            agent_id: ID of the child agent
            agent_name: Name of the child agent

        Returns:
            New ExecutionContext for the child agent
        """
        return ExecutionContext(
            correlation_id=self.correlation_id,
            task_id=self.task_id,
            parent_agent_id=self.current_agent_id,
            current_agent_id=agent_id,
            current_agent_name=agent_name,
            trace_path=[*self.trace_path, agent_name],
            iteration=0,
            depth=self.depth + 1,
            metadata=self.metadata.copy(),
        )

    def with_iteration(self, iteration: int) -> "ExecutionContext":
        """
        Create a copy with an updated iteration.

        Args:
            iteration: Iteration number for the copy

        Returns:
            New ExecutionContext identical to this one (including
            ``created_at``) except for ``iteration``
        """
        return ExecutionContext(
            correlation_id=self.correlation_id,
            task_id=self.task_id,
            parent_agent_id=self.parent_agent_id,
            current_agent_id=self.current_agent_id,
            current_agent_name=self.current_agent_name,
            trace_path=self.trace_path.copy(),
            iteration=iteration,
            depth=self.depth,
            created_at=self.created_at,
            metadata=self.metadata.copy(),
        )

    def with_metadata(self, **kwargs) -> "ExecutionContext":
        """
        Create a copy with additional metadata.

        Args:
            **kwargs: Entries merged over the existing metadata; keys that
                already exist are overwritten in the copy

        Returns:
            New ExecutionContext with the merged metadata
        """
        return ExecutionContext(
            correlation_id=self.correlation_id,
            task_id=self.task_id,
            parent_agent_id=self.parent_agent_id,
            current_agent_id=self.current_agent_id,
            current_agent_name=self.current_agent_name,
            trace_path=self.trace_path.copy(),
            iteration=self.iteration,
            depth=self.depth,
            created_at=self.created_at,
            metadata={**self.metadata, **kwargs},
        )

    @property
    def trace_string(self) -> str:
        """Get trace path as a string (e.g., 'orchestrator > analysis > verification')"""
        return " > ".join(self.trace_path) if self.trace_path else "root"

    @property
    def span_id(self) -> str:
        """
        Get a span ID of the form ``<correlation_id>:<agent_id>:<iteration>``.

        ``unknown`` stands in for a missing agent ID. Two contexts with the
        same correlation ID, agent ID and iteration yield the same span ID.
        """
        agent = self.current_agent_id or "unknown"
        return f"{self.correlation_id}:{agent}:{self.iteration}"

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert context to dictionary for serialization.

        ``trace_path`` and ``metadata`` are shallow copies, so editing the
        returned dictionary does not modify this context.
        """
        return {
            "correlation_id": self.correlation_id,
            "task_id": self.task_id,
            "parent_agent_id": self.parent_agent_id,
            "current_agent_id": self.current_agent_id,
            "current_agent_name": self.current_agent_name,
            "trace_path": list(self.trace_path),
            "trace_string": self.trace_string,
            "iteration": self.iteration,
            "depth": self.depth,
            "created_at": self.created_at,
            "metadata": dict(self.metadata),
        }

    def to_log_dict(self) -> Dict[str, Any]:
        """Get minimal context for logging"""
        return {
            "correlation_id": self.correlation_id,
            "task_id": self.task_id,
            "agent_id": self.current_agent_id,
            "agent_name": self.current_agent_name,
            "trace": self.trace_string,
            "iteration": self.iteration,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExecutionContext":
        """
        Create context from dictionary.

        Missing keys, and keys whose value is None or otherwise empty, fall
        back to the field defaults; a missing or empty ``correlation_id`` or
        ``created_at`` is generated. ``trace_path`` and ``metadata`` are
        copied so the new context does not share containers with ``data``.
        Unknown keys (such as ``trace_string``) are ignored.

        Args:
            data: Dictionary such as the one produced by ``to_dict``

        Returns:
            New ExecutionContext populated from ``data``
        """
        return cls(
            # `or` (rather than a .get default) treats an empty ID like a
            # missing one and only generates a replacement when needed.
            correlation_id=data.get("correlation_id") or generate_correlation_id(),
            task_id=data.get("task_id") or "",
            parent_agent_id=data.get("parent_agent_id"),
            current_agent_id=data.get("current_agent_id"),
            current_agent_name=data.get("current_agent_name"),
            trace_path=list(data.get("trace_path") or []),
            iteration=data.get("iteration") or 0,
            depth=data.get("depth") or 0,
            created_at=data.get("created_at") or _utc_timestamp(),
            metadata=dict(data.get("metadata") or {}),
        )
# ============ Context Manager ============
class ExecutionContextManager:
    """
    Context manager that binds an ExecutionContext to the context variables.

    On entry the correlation ID, task ID, agent name (when set) and trace
    path of the given context are stored in the module's context variables;
    on exit the previous values are restored. Works with both ``with`` and
    ``async with``.

    Usage:
        async with ExecutionContextManager(context) as ctx:
            # Context variables are set for this scope
            await do_something()
    """

    def __init__(self, context: ExecutionContext):
        self.context = context
        # Tokens returned by ContextVar.set(), in the order they were created.
        self._tokens: List[contextvars.Token] = []

    def __enter__(self) -> ExecutionContext:
        """Enter context and set context variables"""
        ctx = self.context
        self._tokens.append(_correlation_id.set(ctx.correlation_id))
        self._tokens.append(_task_id.set(ctx.task_id))
        # With no agent name on the context, the current-agent variable is
        # left untouched, so any value already bound stays visible.
        if ctx.current_agent_name:
            self._tokens.append(_current_agent.set(ctx.current_agent_name))
        # Store a copy so later in-place edits of ctx.trace_path cannot
        # silently change the ambient trace path.
        self._tokens.append(_trace_path.set(list(ctx.trace_path)))
        return ctx

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context and restore previous values"""
        # Undo in reverse order of the set() calls.
        for token in reversed(self._tokens):
            try:
                token.var.reset(token)
            except ValueError:
                # ContextVar.reset raises ValueError when the token was
                # created in a different Context (e.g. exit runs in another
                # task than enter). Restoring is best-effort in that case.
                pass
        self._tokens.clear()
        # Never suppress an exception raised inside the block.
        return False

    async def __aenter__(self) -> ExecutionContext:
        """Async enter context; delegates to ``__enter__``"""
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async exit context; delegates to ``__exit__``"""
        return self.__exit__(exc_type, exc_val, exc_tb)
def create_context(
    task_id: str,
    correlation_id: Optional[str] = None,
    **metadata
) -> ExecutionContext:
    """
    Build the initial execution context for a task.

    Args:
        task_id: The task ID
        correlation_id: Correlation ID to use; when None or empty, one is
            generated with ``generate_correlation_id()``
        **metadata: Keyword arguments stored as the context's metadata

    Returns:
        New ExecutionContext
    """
    if correlation_id:
        cid = correlation_id
    else:
        cid = generate_correlation_id()
    return ExecutionContext(correlation_id=cid, task_id=task_id, metadata=metadata)
def get_current_context() -> ExecutionContext:
    """
    Snapshot the context variables into an ExecutionContext.

    Only the correlation ID, task ID, agent name and trace path are taken
    from the context variables; every other field keeps its dataclass
    default. Reading the correlation ID goes through
    ``get_correlation_id()``, which generates and stores an ID if none is
    set.

    Returns:
        New ExecutionContext reflecting the current context variables
    """
    cid = get_correlation_id()
    task = get_task_id()
    agent = get_current_agent()
    path = get_trace_path()
    return ExecutionContext(
        correlation_id=cid,
        task_id=task,
        current_agent_name=agent,
        trace_path=path,
    )
# ============ Decorators ============
def with_context(context: ExecutionContext):
    """
    Decorator to run an async function with an execution context.

    Each call of the decorated function runs inside a new
    ``ExecutionContextManager(context)`` scope. The wrapper is a coroutine
    function that awaits the wrapped call, so ``func`` must return an
    awaitable. ``functools.wraps`` keeps the wrapped function's name,
    docstring and other metadata on the wrapper.

    Args:
        context: The ExecutionContext to bind during each call

    Returns:
        A decorator producing the wrapped coroutine function

    Usage:
        @with_context(my_context)
        async def my_function():
            # Context variables are set
            pass
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            with ExecutionContextManager(context):
                return await func(*args, **kwargs)
        return wrapper
    return decorator
def traced(agent_name: str):
    """
    Decorator to add an agent to the trace path for the duration of a call.

    ``agent_name`` is pushed onto the trace path before the wrapped call and
    the last trace entry is popped afterwards, also when the call raises.
    The wrapper is a coroutine function that awaits the wrapped call, so
    ``func`` must return an awaitable. ``functools.wraps`` keeps the wrapped
    function's name, docstring and other metadata on the wrapper.

    Args:
        agent_name: Name appended to the trace path while the call runs

    Returns:
        A decorator producing the wrapped coroutine function

    Usage:
        @traced("analysis")
        async def run_analysis():
            # Trace path includes "analysis"
            pass
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            push_trace(agent_name)
            try:
                return await func(*args, **kwargs)
            finally:
                pop_trace()
        return wrapper
    return decorator