Using Callbacks#
Callbacks allow you to hook into the agent’s execution flow, enabling you to log, monitor, or modify the agent’s behavior. This section provides an extensive guide to understanding and using callbacks effectively.
Callback System Architecture#
The callback system in Malevich Brain is designed around an event-based architecture:
Events: Specific points in the agent’s execution flow that trigger callbacks
Event Types: Categories and subcategories of events (e.g., “message.user”)
Callback Functions: User-defined functions that respond to specific events
CallbackStream: Special objects for handling streaming events
When an event occurs during agent execution, all registered callbacks that match the event pattern are triggered. Callbacks can be synchronous or asynchronous.
Creating Callbacks#
You can create callbacks using the @callback decorator:
from brain.agents.callback import callback

@callback("message.user")
async def log_user_message(agent, event, data):
    """Log user messages"""
    print(f"User said: {data.content}")

@callback("message.assistant")
async def log_assistant_message(agent, event, data):
    """Log assistant messages"""
    print(f"Assistant said: {data.content}")
The first parameter to the @callback decorator is a pattern that specifies which events the callback should respond to.
Callback Function Parameters#
Each callback function receives three parameters:
agent: The Agent instance that triggered the callback
event: A string representing the event type (e.g., “message.user”)
data: For regular callbacks, the data associated with the event (e.g., a Message object). For streaming callbacks, the third parameter is instead a CallbackStream object, conventionally named stream.
For regular event callbacks:
@callback("message.user")
async def process_user_message(agent, event, data):
    # data is a Message object
    print(f"Message content: {data.content}")
    print(f"Message role: {data.role}")
For streaming event callbacks:
@callback("message_stream.assistant")
async def process_stream(agent, event, stream):
    # stream is a CallbackStream object
    async for chunk in stream:
        # Process each chunk
        pass
Event Types#
Callbacks can respond to the following event types:
message.user - Triggered when a user message is received
message.assistant - Triggered when the assistant sends a message
message.system - Triggered when a system message is used
message_stream.assistant - Triggered when the assistant starts streaming a message
You can use wildcards in event patterns:
message.* - Matches all message events
*.assistant - Matches all assistant events
Event Matching Rules#
The callback system uses pattern matching to determine which callbacks to trigger for a given event:
Exact matches: “message.user” matches only user message events
Wildcard matches: “message.*” matches all message events
Wildcards in any position: "*.assistant" matches all assistant events
Internally, the system parses the event string into a tuple of components and compares them with the callback’s pattern:
# Examples of event parsing
"message.user" -> ("message", "user")
"message.*" -> ("message", "*")
"*.assistant" -> ("*", "assistant")
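The matching logic described above can be sketched as a small standalone function. The parse and match_events names below mirror the internals shown later in this page, but this is an illustrative reimplementation, not the library's exact code:

```python
def parse(event: str) -> tuple[str, ...]:
    """Split an event string like "message.user" into its components."""
    return tuple(event.split("."))

def match_events(pattern: str, event: str) -> bool:
    """Return True if the event matches the pattern, treating "*" as a wildcard."""
    p, e = parse(pattern), parse(event)
    if len(p) != len(e):
        return False
    return all(pp == "*" or pp == ee for pp, ee in zip(p, e))

print(match_events("message.user", "message.user"))    # True (exact match)
print(match_events("message.*", "message.assistant"))  # True (wildcard match)
print(match_events("*.assistant", "message.user"))     # False (second component differs)
```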
Using Callbacks with Agents#
To use callbacks with an agent, pass them to the agent’s constructor:
from brain.agents.agent import Agent

# Create the agent with callbacks
agent = Agent(
    llm=llm,
    tools=[],
    callbacks=[
        log_user_message,
        log_assistant_message
    ]
)
You can register multiple callbacks for the same event, and they will all be executed in the order they were registered.
Synchronous vs. Asynchronous Callbacks#
Callbacks can be defined as either synchronous or asynchronous functions:
import asyncio

# Synchronous callback
@callback("message.user")
def sync_callback(agent, event, data):
    print(f"Sync: {data.content}")

# Asynchronous callback
@callback("message.user")
async def async_callback(agent, event, data):
    await asyncio.sleep(0.1)  # Simulate async operation
    print(f"Async: {data.content}")
The Malevich Brain agent will properly handle both types, automatically awaiting async callbacks when needed. However, for best performance with I/O-bound operations (like database access or network calls), asynchronous callbacks are recommended.
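One way a dispatcher can support both kinds of callbacks, while preserving registration order, is to inspect the return value and await it only when it is awaitable. This is a standalone sketch of the idea, not the library's actual dispatch code:

```python
import asyncio
import inspect

async def dispatch(callbacks, agent, event, data):
    """Invoke each callback in registration order, awaiting coroutines."""
    for cb in callbacks:
        result = cb(agent, event, data)
        if inspect.isawaitable(result):
            await result

# Example: one sync and one async callback registered together
order = []

def sync_cb(agent, event, data):
    order.append("sync")

async def async_cb(agent, event, data):
    await asyncio.sleep(0)  # simulate an async operation
    order.append("async")

asyncio.run(dispatch([sync_cb, async_cb], None, "message.user", None))
print(order)  # ['sync', 'async'] - registration order is preserved
```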
Streaming Content with Callbacks#
The message_stream.assistant event provides streaming content:
@callback("message_stream.assistant")
async def stream_handler(agent, event, stream):
    """Handle streaming content"""
    print("Assistant: ", end="", flush=True)

    async for chunk in stream:
        if hasattr(chunk, "chunk"):
            # Print each chunk as it arrives
            print(chunk.chunk, end="", flush=True)
        elif hasattr(chunk, "content") and chunk.content:
            # Handle complete messages
            print(f"\nComplete message: {chunk.content}")
The stream parameter in this callback is a CallbackStream object that yields either:
MessageChunk objects: Small pieces of the message being generated
Message objects: Complete messages
Working with the CallbackStream#
The CallbackStream is an async iterator that provides content as it’s generated. It allows real-time processing of LLM outputs.
@callback("message_stream.assistant")
async def detailed_stream_handler(agent, event, stream):
    """Detailed handling of streaming content"""
    buffer = ""

    # Iterate through the stream
    async for item in stream:
        # Different types in the stream
        if hasattr(item, "chunk"):
            # This is a MessageChunk
            chunk_text = item.chunk
            buffer += chunk_text

            # Process each chunk (e.g., print to console)
            print(chunk_text, end="", flush=True)

            # Example: Analyze content as it streams
            if "error" in chunk_text.lower():
                print("\n[Alert: Error mentioned!]")
        elif hasattr(item, "content") and hasattr(item, "role"):
            # This is a complete Message
            print(f"\n[Complete message received: {len(item.content)} chars]")

    # Stream is complete
    print(f"\n[Stream finished, total length: {len(buffer)}]")
The stream automatically terminates when the entire response has been generated. The .done() method is called internally.
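Conceptually, such a stream can be modeled as an async iterator over a queue, with a sentinel marking completion when done() is called. The class below is an illustrative reimplementation of that idea, not the library's actual CallbackStream:

```python
import asyncio

_DONE = object()  # sentinel marking the end of the stream

class SketchCallbackStream:
    """Minimal async-iterator stream backed by a queue (illustrative only)."""

    def __init__(self):
        self._queue = asyncio.Queue()

    async def put(self, item):
        await self._queue.put(item)

    async def done(self):
        """Signal that no more items will arrive."""
        await self._queue.put(_DONE)

    def __aiter__(self):
        return self

    async def __anext__(self):
        item = await self._queue.get()
        if item is _DONE:
            raise StopAsyncIteration
        return item

async def demo():
    stream = SketchCallbackStream()

    async def producer():
        for chunk in ["Hel", "lo"]:
            await stream.put(chunk)
        await stream.done()  # terminates the consumer's async-for loop

    asyncio.get_running_loop().create_task(producer())
    return [chunk async for chunk in stream]

print(asyncio.run(demo()))  # ['Hel', 'lo']
```

This queue-plus-sentinel shape is why the consumer side never needs to check for completion explicitly: exhausting the async-for loop is the completion signal.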
Practical Use Cases#
Logging to a Database:
import asyncio
from datetime import datetime

import asyncpg

async def setup_db_pool():
    return await asyncpg.create_pool(
        "postgresql://user:password@localhost/dbname"
    )

# Shared connection pool
# Note: in a long-running app, create the pool inside the running event loop instead
db_pool = asyncio.run(setup_db_pool())

@callback("message.*")
async def db_logger(agent, event, data):
    """Log all messages to a database"""
    async with db_pool.acquire() as conn:
        await conn.execute(
            """
            INSERT INTO message_logs(role, content, event_type, timestamp)
            VALUES($1, $2, $3, $4)
            """,
            data.role, data.content, event, datetime.now()
        )
Real-time UI Updates:
import asyncio
import json

import websockets

# Active WebSocket connections
active_connections = set()

@callback("message_stream.assistant")
async def update_ui(agent, event, stream):
    """Send streaming updates to a UI via WebSocket"""
    # Send the start message to all connections
    await broadcast({"type": "stream_start"})

    # Process each chunk
    async for chunk in stream:
        if hasattr(chunk, "chunk"):
            # Send each chunk to connected clients
            await broadcast({
                "type": "stream_chunk",
                "content": chunk.chunk
            })

    # Send the end message
    await broadcast({"type": "stream_end"})

async def broadcast(message):
    """Send a message to all connected WebSocket clients"""
    if active_connections:
        await asyncio.gather(*[
            connection.send(json.dumps(message))
            for connection in active_connections
        ])
Monitoring Tool Usage:
# Track tool usage metrics
tool_metrics = {
    "calls": {},
    "errors": {},
    "execution_time": {}
}

@callback("message.user")
async def monitor_tool_calls(agent, event, data):
    """Monitor tool usage"""
    # Only process tool response messages
    if hasattr(data, "tool_call_id") and hasattr(data, "error_message"):
        tool_name = data.content.split(" ")[1]  # Extract tool name from content

        # Initialize counters if needed
        if tool_name not in tool_metrics["calls"]:
            tool_metrics["calls"][tool_name] = 0
            tool_metrics["errors"][tool_name] = 0
            tool_metrics["execution_time"][tool_name] = []

        # Update metrics
        tool_metrics["calls"][tool_name] += 1

        # Track errors
        if data.error_message:
            tool_metrics["errors"][tool_name] += 1

# Report metrics periodically
@callback("message.assistant")
async def log_metrics_periodically(agent, event, data):
    # Every 10 assistant messages, log metrics
    total_calls = sum(tool_metrics["calls"].values())
    if total_calls > 0 and total_calls % 10 == 0:
        print("\n--- Tool Usage Metrics ---")
        for tool_name, call_count in tool_metrics["calls"].items():
            error_rate = tool_metrics["errors"][tool_name] / call_count if call_count > 0 else 0
            print(f"{tool_name}: {call_count} calls, {error_rate:.2%} error rate")
        print("-------------------------\n")
Token Counting and Cost Tracking:
import tiktoken

class TokenCounter:
    def __init__(self, model_name="gpt-4o"):
        self.model = model_name
        self.encoder = tiktoken.encoding_for_model(model_name)
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0

    def count_tokens(self, text):
        return len(self.encoder.encode(text))

# Initialize counter
token_counter = TokenCounter()

@callback("message.user")
async def count_prompt_tokens(agent, event, data):
    """Count tokens in user messages"""
    tokens = token_counter.count_tokens(data.content)
    token_counter.total_prompt_tokens += tokens

@callback("message.assistant")
async def count_completion_tokens(agent, event, data):
    """Count tokens in assistant responses"""
    tokens = token_counter.count_tokens(data.content)
    token_counter.total_completion_tokens += tokens

    # Report totals occasionally
    if token_counter.total_completion_tokens % 1000 < tokens:
        # Estimate cost (based on approximate OpenAI pricing)
        prompt_cost = token_counter.total_prompt_tokens * 0.00001      # $0.01 per 1K tokens
        completion_cost = token_counter.total_completion_tokens * 0.00003  # $0.03 per 1K tokens
        total_cost = prompt_cost + completion_cost
        print(f"\nToken usage: {token_counter.total_prompt_tokens} prompt, "
              f"{token_counter.total_completion_tokens} completion")
        print(f"Estimated cost: ${total_cost:.4f}")
Advanced Callback Patterns#
Conversation History Tracking:
import json
from datetime import datetime

# Track conversation state
conversation_history = []

@callback("message.*")
async def track_conversation(agent, event, data):
    """Add messages to conversation history"""
    conversation_history.append({
        "role": data.role,
        "content": data.content,
        "timestamp": datetime.now().isoformat()
    })

    # Limit history size (keep last 100 messages)
    if len(conversation_history) > 100:
        conversation_history.pop(0)

# Save conversation periodically
@callback("message.user")
async def save_conversation(agent, event, data):
    # Save every 10 messages
    if len(conversation_history) % 10 == 0:
        filename = f"conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(filename, "w") as f:
            json.dump(conversation_history, f, indent=2)
Tool Call Lifecycle Monitoring:
import time

# Track active tool calls
active_tool_calls = {}

@callback("message.assistant")
async def track_tool_call_start(agent, event, data):
    """Track when a tool call is initiated"""
    if hasattr(data, "tool_call_id") and hasattr(data, "tool_name"):
        active_tool_calls[data.tool_call_id] = {
            "tool_name": data.tool_name,
            "start_time": time.time(),
            "status": "started"
        }

@callback("message.user")
async def track_tool_call_end(agent, event, data):
    """Track when a tool call is completed"""
    if hasattr(data, "tool_call_id"):
        call_id = data.tool_call_id
        if call_id in active_tool_calls:
            active_tool_calls[call_id]["end_time"] = time.time()
            active_tool_calls[call_id]["status"] = "completed"

            # Calculate duration
            duration = active_tool_calls[call_id]["end_time"] - active_tool_calls[call_id]["start_time"]
            active_tool_calls[call_id]["duration"] = duration

            if hasattr(data, "error_message") and data.error_message:
                active_tool_calls[call_id]["status"] = "error"
                active_tool_calls[call_id]["error"] = data.error_message

            # Log completion
            tool_name = active_tool_calls[call_id]["tool_name"]
            status = active_tool_calls[call_id]["status"]
            print(f"Tool {tool_name} {status} in {duration:.4f}s")
Debug Logging Decorator:
def with_debug_logging(event_pattern):
    """Decorator to add debug logging to a callback"""
    def decorator(callback_func):
        # Create the callback as normal
        cb = callback(event_pattern)(callback_func)

        # Create a wrapper around the original route method
        original_route = cb.route

        async def route_with_logging(agent, event, data):
            print(f"DEBUG: Callback triggered for event {event}")
            try:
                # Call the original route method
                result = await original_route(agent, event, data)
                print("DEBUG: Callback completed successfully")
                return result
            except Exception as e:
                print(f"DEBUG: Callback failed with error: {e}")
                raise

        # Replace the route method
        cb.route = route_with_logging
        return cb
    return decorator

# Usage
@with_debug_logging("message.user")
async def my_callback(agent, event, data):
    print(f"Processing message: {data.content}")
Chaining Callbacks:
import asyncio

class CallbackChain:
    """Chain multiple callbacks together"""

    def __init__(self, event_pattern, callbacks):
        self.event_pattern = event_pattern
        self.callbacks = callbacks

    def create(self):
        """Create a single callback that runs all callbacks in the chain"""
        @callback(self.event_pattern)
        async def chained_callback(agent, event, data):
            for cb_func in self.callbacks:
                result = cb_func(agent, event, data)
                if result is not None and asyncio.iscoroutine(result):
                    await result
        return chained_callback

# Usage
def log_to_console(agent, event, data):
    print(f"CONSOLE: {data.content}")

async def log_to_database(agent, event, data):
    print(f"DATABASE: Saving {data.content}")
    await asyncio.sleep(0.1)  # Simulate database operation

# Create a chain
message_chain = CallbackChain("message.user", [
    log_to_console,
    log_to_database
]).create()

# Use in an agent
agent = Agent(llm=llm, tools=[], callbacks=[message_chain])
Best Practices#
Performance Considerations:
Keep callbacks lightweight and fast to avoid slowing down the agent
Use asynchronous callbacks for I/O operations
Avoid blocking operations in synchronous callbacks
Be mindful of memory usage in callbacks that store data
Error Handling:
Implement proper error handling in callbacks to prevent them from crashing the agent
Use try/except blocks to catch and log errors
@callback("message.user")
async def safe_callback(agent, event, data):
    try:
        # Potentially risky operation
        result = await external_api_call(data.content)
        print(f"API result: {result}")
    except Exception as e:
        print(f"Callback error (non-fatal): {e}")
        # Continue execution without disrupting the agent
Resource Management:
Initialize expensive resources (database connections, API clients) outside callbacks
Use connection pooling for database connections
Close resources properly when they’re no longer needed
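The "initialize once, borrow per event" pattern above can be sketched with a stand-in pool class (FakePool is hypothetical, used only to make the construction cost observable):

```python
import asyncio

creation_count = 0  # counts how many pools get constructed

class FakePool:
    """Stand-in for a real connection pool (illustrative only)."""

    def __init__(self):
        global creation_count
        creation_count += 1  # constructing a real pool is expensive

    def acquire(self):
        return "connection"  # borrowing from an existing pool is cheap

# Good: initialize once at module scope, reuse inside callbacks
shared_pool = FakePool()

async def pooled_callback(agent, event, data):
    conn = shared_pool.acquire()  # borrows from the shared pool
    return conn

async def main():
    # Many events, but still only one pool construction
    for _ in range(5):
        await pooled_callback(None, "message.user", None)

asyncio.run(main())
print(creation_count)  # 1
```

The anti-pattern is constructing FakePool() inside the callback body, which would pay the setup cost on every event.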
Debugging Callbacks:
Add logging to track callback execution
Use unique identifiers for messages and tool calls to trace execution flow
Consider using the debug logging decorator pattern shown above
Organization:
Group related callbacks together in modules
Use descriptive names for callback functions
Document the purpose and behavior of each callback
# In callbacks/logging.py
@callback("message.*")
async def log_all_messages(agent, event, data):
    """Log all messages to console with timestamp"""

# In callbacks/metrics.py
@callback("message.user")
async def track_tool_usage(agent, event, data):
    """Track and record tool usage metrics"""

# In main.py
from callbacks.logging import log_all_messages
from callbacks.metrics import track_tool_usage

agent = Agent(
    llm=llm,
    tools=tools,
    callbacks=[log_all_messages, track_tool_usage]
)
Callback Implementation Internals#
Understanding how callbacks are implemented internally can help you use them more effectively:
Callback Registration:
When you create an agent with callbacks, they are stored in the agent's callbacks list:

# In Agent.__init__
self.callbacks = callbacks

Event Triggering:
The agent triggers events at specific points using the execute_callback method:

# In Agent.run
await self.execute_callback('message.user', message)
# ... later ...
await self.execute_callback('message.assistant', message)

Callback Execution:
The execute_callback method routes events to matching callbacks:

async def execute_callback(self, event: str, data: Any) -> None:
    await asyncio.gather(*[
        callback.route(self, event, data)
        for callback in self.callbacks
    ])

This gathers all the callback coroutines and runs them concurrently with asyncio.gather.

Streaming Implementation:
For streaming events, a special method execute_stream_callback creates CallbackStream objects:

async def execute_stream_callback(self, event: str) -> list[tuple[CallbackStream, asyncio.Task[None]]]:
    cbs = [cb for cb in self.callbacks
           if Callback.match_events(cb.event_type, Callback.parse(event))]
    streams = [CallbackStream() for cb in cbs]
    tasks = [asyncio.create_task(cb.route(self, event, stream))
             for cb, stream in zip(cbs, streams)]
    return list(zip(streams, tasks))