Using Callbacks
===============

Callbacks allow you to hook into the agent's execution flow, enabling you to log, monitor, or modify the agent's behavior. This section provides an extensive guide to understanding and using callbacks effectively.

Callback System Architecture
----------------------------

The callback system in Malevich Brain is designed around an event-based architecture:

1. **Events**: Specific points in the agent's execution flow that trigger callbacks
2. **Event types**: Categories and subcategories of events (e.g., ``message.user``)
3. **Callback functions**: User-defined functions that respond to specific events
4. **CallbackStream**: Special objects for handling streaming events

When an event occurs during agent execution, every registered callback whose pattern matches the event is triggered. Callbacks can be synchronous or asynchronous.

Creating Callbacks
------------------

You can create callbacks using the ``@callback`` decorator:

.. code-block:: python

    from brain.agents.callback import callback

    @callback("message.user")
    async def log_user_message(agent, event, data):
        """Log user messages"""
        print(f"User said: {data.content}")

    @callback("message.assistant")
    async def log_assistant_message(agent, event, data):
        """Log assistant messages"""
        print(f"Assistant said: {data.content}")

The first parameter to the ``@callback`` decorator is a pattern that specifies which events the callback should respond to.

Callback Function Parameters
----------------------------

Each callback function receives three parameters:

1. **agent**: The ``Agent`` instance that triggered the callback
2. **event**: A string identifying the event type (e.g., ``"message.user"``)
3. **data**: For regular callbacks, the data associated with the event (e.g., a ``Message`` object). For streaming callbacks, the third parameter is instead a ``CallbackStream`` object, conventionally named **stream**.

For regular event callbacks:

.. code-block:: python

    @callback("message.user")
    async def process_user_message(agent, event, data):
        # data is a Message object
        print(f"Message content: {data.content}")
        print(f"Message role: {data.role}")

For streaming event callbacks:

.. code-block:: python

    @callback("message_stream.assistant")
    async def process_stream(agent, event, stream):
        # stream is a CallbackStream object
        async for chunk in stream:
            # Process each chunk
            pass

Event Types
-----------

Callbacks can respond to the following event types:

- ``message.user`` - Triggered when a user message is received
- ``message.assistant`` - Triggered when the assistant sends a message
- ``message.system`` - Triggered when a system message is used
- ``message_stream.assistant`` - Triggered when the assistant starts streaming a message

You can use wildcards in event patterns:

- ``message.*`` - Matches all message events
- ``*.assistant`` - Matches all assistant events

Event Matching Rules
--------------------

The callback system uses pattern matching to determine which callbacks to trigger for a given event:

1. Exact matches: ``"message.user"`` matches only user message events
2. Wildcard subcategories: ``"message.*"`` matches all message events
3. Wildcard categories: ``"*.assistant"`` matches all assistant events

Internally, the system parses the event string into a tuple of components and compares them with the callback's pattern:

.. code-block:: python

    # Examples of event parsing
    # "message.user" -> ("message", "user")
    # "message.*"    -> ("message", "*")
    # "*.assistant"  -> ("*", "assistant")

Using Callbacks with Agents
---------------------------

To use callbacks with an agent, pass them to the agent's constructor:

.. code-block:: python

    from brain.agents.agent import Agent

    # Create the agent with callbacks
    agent = Agent(
        llm=llm,
        tools=[],
        callbacks=[
            log_user_message,
            log_assistant_message
        ]
    )

You can register multiple callbacks for the same event; when the event fires, every matching callback is triggered (internally they run concurrently, as described under Callback Implementation Internals).

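The matching rules above can be illustrated with a short, self-contained sketch. Note that ``parse`` and ``matches`` here are illustrative helpers written for this guide, not the library's actual API; the internal equivalents may differ in detail.

```python
def parse(event: str) -> tuple[str, ...]:
    """Split an event string such as "message.user" into its components."""
    return tuple(event.split("."))

def matches(pattern: str, event: str) -> bool:
    """Return True if the pattern (with optional "*" wildcards) matches the event."""
    p, e = parse(pattern), parse(event)
    # Same number of components, each either equal or wildcarded
    return len(p) == len(e) and all(pp in ("*", ee) for pp, ee in zip(p, e))

print(matches("message.user", "message.user"))            # True  (exact match)
print(matches("message.*", "message.assistant"))          # True  (wildcard subcategory)
print(matches("*.assistant", "message.assistant"))        # True  (wildcard category)
print(matches("message.*", "message_stream.assistant"))   # False (category differs)
```

Because components are compared positionally, ``message.*`` does not match ``message_stream.assistant`` even though the category names share a prefix.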
Synchronous vs. Asynchronous Callbacks
--------------------------------------

Callbacks can be defined as either synchronous or asynchronous functions:

.. code-block:: python

    import asyncio

    # Synchronous callback
    @callback("message.user")
    def sync_callback(agent, event, data):
        print(f"Sync: {data.content}")

    # Asynchronous callback
    @callback("message.user")
    async def async_callback(agent, event, data):
        await asyncio.sleep(0.1)  # Simulate an async operation
        print(f"Async: {data.content}")

The Malevich Brain agent handles both types, automatically awaiting asynchronous callbacks when needed. For I/O-bound operations (such as database access or network calls), asynchronous callbacks are recommended so they do not block the event loop.

Streaming Content with Callbacks
--------------------------------

The ``message_stream.assistant`` event provides streaming content:

.. code-block:: python

    @callback("message_stream.assistant")
    async def stream_handler(agent, event, stream):
        """Handle streaming content"""
        print("Assistant: ", end="", flush=True)
        async for chunk in stream:
            if hasattr(chunk, "chunk"):
                # Print each chunk as it arrives
                print(chunk.chunk, end="", flush=True)
            elif hasattr(chunk, "content") and chunk.content:
                # Handle complete messages
                print(f"\nComplete message: {chunk.content}")

The ``stream`` parameter in this callback is a ``CallbackStream`` object that yields either:

1. ``MessageChunk`` objects: Small pieces of the message as it is generated
2. ``Message`` objects: Complete messages

Working with the CallbackStream
-------------------------------

The ``CallbackStream`` is an async iterator that provides content as it is generated, allowing real-time processing of LLM output.

.. code-block:: python

    @callback("message_stream.assistant")
    async def detailed_stream_handler(agent, event, stream):
        """Detailed handling of streaming content"""
        buffer = ""

        # Iterate through the stream
        async for item in stream:
            # Different types appear in the stream
            if hasattr(item, "chunk"):
                # This is a MessageChunk
                chunk_text = item.chunk
                buffer += chunk_text

                # Process each chunk (e.g., print to console)
                print(chunk_text, end="", flush=True)

                # Example: Analyze content as it streams
                if "error" in chunk_text.lower():
                    print("\n[Alert: Error mentioned!]")
            elif hasattr(item, "content") and hasattr(item, "role"):
                # This is a complete Message
                print(f"\n[Complete message received: {len(item.content)} chars]")

        # Stream is complete
        print(f"\n[Stream finished, total length: {len(buffer)}]")

The stream terminates automatically once the entire response has been generated; the ``.done()`` method is called internally.

Practical Use Cases
-------------------

1. **Logging to a Database**:

.. code-block:: python

    from datetime import datetime

    import asyncpg

    # Shared connection pool, created lazily inside the running event loop
    # (creating it with asyncio.run() at import time would bind it to a
    # loop that is already closed by the time callbacks fire)
    db_pool = None

    async def get_db_pool():
        global db_pool
        if db_pool is None:
            db_pool = await asyncpg.create_pool(
                "postgresql://user:password@localhost/dbname"
            )
        return db_pool

    @callback("message.*")
    async def db_logger(agent, event, data):
        """Log all messages to a database"""
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO message_logs(role, content, event_type, timestamp)
                VALUES($1, $2, $3, $4)
                """,
                data.role, data.content, event, datetime.now()
            )

2. **Real-time UI Updates**:

.. code-block:: python

    import asyncio
    import json

    import websockets

    # Active WebSocket connections
    active_connections = set()

    @callback("message_stream.assistant")
    async def update_ui(agent, event, stream):
        """Send streaming updates to a UI via WebSocket"""
        # Send the start message to all connections
        await broadcast({"type": "stream_start"})

        # Process each chunk
        async for chunk in stream:
            if hasattr(chunk, "chunk"):
                # Send each chunk to connected clients
                await broadcast({
                    "type": "stream_chunk",
                    "content": chunk.chunk
                })

        # Send the end message
        await broadcast({"type": "stream_end"})

    async def broadcast(message):
        """Send a message to all connected WebSocket clients"""
        if active_connections:
            payload = json.dumps(message)
            await asyncio.gather(*[
                connection.send(payload)
                for connection in active_connections
            ])

3. **Monitoring Tool Usage**:

.. code-block:: python

    # Track tool usage metrics
    tool_metrics = {
        "calls": {},
        "errors": {},
        "execution_time": {}
    }

    @callback("message.user")
    async def monitor_tool_calls(agent, event, data):
        """Monitor tool usage"""
        # Only process tool response messages
        if hasattr(data, "tool_call_id") and hasattr(data, "error_message"):
            # Extract the tool name from the content (format-dependent)
            tool_name = data.content.split(" ")[1]

            # Initialize counters if needed
            if tool_name not in tool_metrics["calls"]:
                tool_metrics["calls"][tool_name] = 0
                tool_metrics["errors"][tool_name] = 0
                tool_metrics["execution_time"][tool_name] = []

            # Update metrics
            tool_metrics["calls"][tool_name] += 1

            # Track errors
            if data.error_message:
                tool_metrics["errors"][tool_name] += 1

    # Report metrics periodically
    @callback("message.assistant")
    async def log_metrics_periodically(agent, event, data):
        # Every 10 assistant messages, log metrics
        total_calls = sum(tool_metrics["calls"].values())
        if total_calls > 0 and total_calls % 10 == 0:
            print("\n--- Tool Usage Metrics ---")
            for tool_name, call_count in tool_metrics["calls"].items():
                error_rate = tool_metrics["errors"][tool_name] / call_count if call_count > 0 else 0
                print(f"{tool_name}: {call_count} calls, {error_rate:.2%} error rate")
            print("-------------------------\n")

4. **Token Counting and Cost Tracking**:

.. code-block:: python

    import tiktoken

    class TokenCounter:
        def __init__(self, model_name="gpt-4o"):
            self.model = model_name
            self.encoder = tiktoken.encoding_for_model(model_name)
            self.total_prompt_tokens = 0
            self.total_completion_tokens = 0

        def count_tokens(self, text):
            return len(self.encoder.encode(text))

    # Initialize counter
    token_counter = TokenCounter()

    @callback("message.user")
    async def count_prompt_tokens(agent, event, data):
        """Count tokens in user messages"""
        tokens = token_counter.count_tokens(data.content)
        token_counter.total_prompt_tokens += tokens

    @callback("message.assistant")
    async def count_completion_tokens(agent, event, data):
        """Count tokens in assistant responses"""
        tokens = token_counter.count_tokens(data.content)
        token_counter.total_completion_tokens += tokens

        # Report totals roughly every 1000 completion tokens
        if token_counter.total_completion_tokens % 1000 < tokens:
            # Estimate cost (based on approximate OpenAI pricing)
            prompt_cost = token_counter.total_prompt_tokens * 0.00001         # $0.01 per 1K tokens
            completion_cost = token_counter.total_completion_tokens * 0.00003  # $0.03 per 1K tokens
            total_cost = prompt_cost + completion_cost

            print(f"\nToken usage: {token_counter.total_prompt_tokens} prompt, "
                  f"{token_counter.total_completion_tokens} completion")
            print(f"Estimated cost: ${total_cost:.4f}")

Advanced Callback Patterns
--------------------------

1. **Conversation History Tracking**:

.. code-block:: python

    import json
    from datetime import datetime

    # Track conversation state
    conversation_history = []

    @callback("message.*")
    async def track_conversation(agent, event, data):
        """Add messages to conversation history"""
        conversation_history.append({
            "role": data.role,
            "content": data.content,
            "timestamp": datetime.now().isoformat()
        })

        # Limit history size (keep last 100 messages)
        if len(conversation_history) > 100:
            conversation_history.pop(0)

    # Save conversation periodically
    @callback("message.user")
    async def save_conversation(agent, event, data):
        # Save every 10 messages
        if len(conversation_history) % 10 == 0:
            filename = f"conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            with open(filename, "w") as f:
                json.dump(conversation_history, f, indent=2)

2. **Tool Call Lifecycle Monitoring**:

.. code-block:: python

    import time

    # Track active tool calls
    active_tool_calls = {}

    @callback("message.assistant")
    async def track_tool_call_start(agent, event, data):
        """Track when a tool call is initiated"""
        if hasattr(data, "tool_call_id") and hasattr(data, "tool_name"):
            active_tool_calls[data.tool_call_id] = {
                "tool_name": data.tool_name,
                "start_time": time.time(),
                "status": "started"
            }

    @callback("message.user")
    async def track_tool_call_end(agent, event, data):
        """Track when a tool call is completed"""
        if hasattr(data, "tool_call_id"):
            call_id = data.tool_call_id
            if call_id in active_tool_calls:
                call = active_tool_calls[call_id]
                call["end_time"] = time.time()
                call["status"] = "completed"

                # Calculate duration
                duration = call["end_time"] - call["start_time"]
                call["duration"] = duration

                if hasattr(data, "error_message") and data.error_message:
                    call["status"] = "error"
                    call["error"] = data.error_message

                # Log completion
                print(f"Tool {call['tool_name']} {call['status']} in {duration:.4f}s")

3. **Debug Logging Decorator**:

.. code-block:: python

    def with_debug_logging(event_pattern):
        """Decorator to add debug logging to a callback"""
        def decorator(callback_func):
            # Create the callback as normal
            cb = callback(event_pattern)(callback_func)

            # Wrap the original route method
            original_route = cb.route

            async def route_with_logging(agent, event, data):
                print(f"DEBUG: Callback triggered for event {event}")
                try:
                    # Call the original route method
                    result = await original_route(agent, event, data)
                    print("DEBUG: Callback completed successfully")
                    return result
                except Exception as e:
                    print(f"DEBUG: Callback failed with error: {e}")
                    raise

            # Replace the route method
            cb.route = route_with_logging
            return cb
        return decorator

    # Usage
    @with_debug_logging("message.user")
    async def my_callback(agent, event, data):
        print(f"Processing message: {data.content}")

4. **Chaining Callbacks**:

.. code-block:: python

    import asyncio

    class CallbackChain:
        """Chain multiple callbacks together"""

        def __init__(self, event_pattern, callbacks):
            self.event_pattern = event_pattern
            self.callbacks = callbacks

        def create(self):
            """Create a single callback that runs all callbacks in the chain"""
            @callback(self.event_pattern)
            async def chained_callback(agent, event, data):
                for cb_func in self.callbacks:
                    result = cb_func(agent, event, data)
                    if result is not None and asyncio.iscoroutine(result):
                        await result

            return chained_callback

    # Usage
    def log_to_console(agent, event, data):
        print(f"CONSOLE: {data.content}")

    async def log_to_database(agent, event, data):
        print(f"DATABASE: Saving {data.content}")
        await asyncio.sleep(0.1)  # Simulate a database operation

    # Create a chain
    message_chain = CallbackChain("message.user", [
        log_to_console,
        log_to_database
    ]).create()

    # Use it in an agent
    agent = Agent(llm=llm, tools=[], callbacks=[message_chain])

Best Practices
--------------

1. **Performance Considerations**:

   - Keep callbacks lightweight and fast to avoid slowing down the agent
   - Use asynchronous callbacks for I/O operations
   - Avoid blocking operations in synchronous callbacks
   - Be mindful of memory usage in callbacks that store data

2. **Error Handling**:

   - Implement proper error handling in callbacks to prevent them from crashing the agent
   - Use try/except blocks to catch and log errors

   .. code-block:: python

       @callback("message.user")
       async def safe_callback(agent, event, data):
           try:
               # Potentially risky operation
               result = await external_api_call(data.content)
               print(f"API result: {result}")
           except Exception as e:
               print(f"Callback error (non-fatal): {e}")
               # Continue execution without disrupting the agent

3. **Resource Management**:

   - Initialize expensive resources (database connections, API clients) outside callbacks
   - Use connection pooling for database connections
   - Close resources properly when they are no longer needed

4. **Debugging Callbacks**:

   - Add logging to track callback execution
   - Use unique identifiers for messages and tool calls to trace execution flow
   - Consider using the debug logging decorator pattern shown above

5. **Organization**:

   - Group related callbacks together in modules
   - Use descriptive names for callback functions
   - Document the purpose and behavior of each callback

.. code-block:: python

    # In callbacks/logging.py
    @callback("message.*")
    async def log_all_messages(agent, event, data):
        """Log all messages to console with timestamp"""

    # In callbacks/metrics.py
    @callback("message.user")
    async def track_tool_usage(agent, event, data):
        """Track and record tool usage metrics"""

    # In main.py
    from callbacks.logging import log_all_messages
    from callbacks.metrics import track_tool_usage

    agent = Agent(
        llm=llm,
        tools=tools,
        callbacks=[log_all_messages, track_tool_usage]
    )

Callback Implementation Internals
---------------------------------

Understanding how callbacks are implemented internally can help you use them more effectively:

1. **Callback Registration**: When you create an agent with callbacks, they are stored in the agent's ``callbacks`` list:

   .. code-block:: python

       # In Agent.__init__
       self.callbacks = callbacks

2. **Event Triggering**: The agent triggers events at specific points using the ``execute_callback`` method:

   .. code-block:: python

       # In Agent.run
       await self.execute_callback('message.user', message)
       # ... later ...
       await self.execute_callback('message.assistant', message)

3. **Callback Execution**: The ``execute_callback`` method routes events to matching callbacks:

   .. code-block:: python

       async def execute_callback(self, event: str, data: Any) -> None:
           await asyncio.gather(*[
               callback.route(self, event, data)
               for callback in self.callbacks
           ])

   This gathers all the callback coroutines and runs them concurrently with ``asyncio.gather``.

4. **Streaming Implementation**: For streaming events, a special method ``execute_stream_callback`` creates ``CallbackStream`` objects:

   .. code-block:: python

       async def execute_stream_callback(
           self, event: str
       ) -> list[tuple[CallbackStream, asyncio.Task[None]]]:
           cbs = [
               cb for cb in self.callbacks
               if Callback.match_events(cb.event_type, Callback.parse(event))
           ]
           streams = [CallbackStream() for cb in cbs]
           tasks = [
               asyncio.create_task(cb.route(self, event, stream))
               for cb, stream in zip(cbs, streams)
           ]
           return list(zip(streams, tasks))

   This finds every callback whose pattern matches the event, creates one ``CallbackStream`` per matching callback, and starts a task that routes each stream to its callback.
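To make the streaming mechanics above concrete, here is a minimal, hypothetical sketch of how a ``CallbackStream``-like object can work: an async iterator fed from a queue by a producer task and closed with a ``done()`` sentinel. ``MiniCallbackStream``, ``put``, and ``demo`` are illustrative names for this guide, not the library's actual API.

```python
import asyncio

class MiniCallbackStream:
    """Illustrative stand-in for CallbackStream (not the library class)."""

    _DONE = object()  # Sentinel marking the end of the stream

    def __init__(self):
        self._queue: asyncio.Queue = asyncio.Queue()

    async def put(self, item):
        """Feed one chunk into the stream (producer side)."""
        await self._queue.put(item)

    async def done(self):
        """Signal that the stream is complete."""
        await self._queue.put(self._DONE)

    def __aiter__(self):
        return self

    async def __anext__(self):
        item = await self._queue.get()
        if item is self._DONE:
            raise StopAsyncIteration
        return item

async def demo():
    stream = MiniCallbackStream()

    async def produce():
        for chunk in ["Hel", "lo", "!"]:
            await stream.put(chunk)
        await stream.done()

    # Producer and consumer run concurrently, mirroring how
    # execute_stream_callback hands a stream to a callback task
    producer = asyncio.create_task(produce())
    chunks = [c async for c in stream]
    await producer
    return "".join(chunks)

print(asyncio.run(demo()))  # prints: Hello!
```

The real ``CallbackStream`` yields ``MessageChunk`` and ``Message`` objects rather than plain strings, but the iteration and termination mechanics are analogous.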