Using Callbacks#

Callbacks let you hook into the agent’s execution flow so you can log, monitor, or modify the agent’s behavior. This section is a practical guide to understanding and using callbacks effectively.

Callback System Architecture#

The callback system in Malevich Brain is designed around an event-based architecture:

  1. Events: Specific points in the agent’s execution flow that trigger callbacks

  2. Event Types: Categories and subcategories of events (e.g., “message.user”)

  3. Callback Functions: User-defined functions that respond to specific events

  4. CallbackStream: Special objects for handling streaming events

When an event occurs during agent execution, all registered callbacks that match the event pattern are triggered. Callbacks can be synchronous or asynchronous.

Creating Callbacks#

You can create callbacks using the @callback decorator:

from brain.agents.callback import callback

@callback("message.user")
async def log_user_message(agent, event, data):
    """Log user messages"""
    print(f"User said: {data.content}")

@callback("message.assistant")
async def log_assistant_message(agent, event, data):
    """Log assistant messages"""
    print(f"Assistant said: {data.content}")

The first parameter to the @callback decorator is a pattern that specifies which events the callback should respond to.

Callback Function Parameters#

Each callback function receives three parameters:

  1. agent: The Agent instance that triggered the callback

  2. event: A string representing the event type (e.g., “message.user”)

  3. data or stream: For regular callbacks, the data associated with the event (e.g., a Message object); for streaming callbacks, a CallbackStream object

For regular event callbacks:

@callback("message.user")
async def process_user_message(agent, event, data):
    # data is a Message object
    print(f"Message content: {data.content}")
    print(f"Message role: {data.role}")

For streaming event callbacks:

@callback("message_stream.assistant")
async def process_stream(agent, event, stream):
    # stream is a CallbackStream object
    async for chunk in stream:
        # Process each chunk
        pass

Event Types#

Callbacks can respond to the following event types:

  • message.user - Triggered when a user message is received

  • message.assistant - Triggered when the assistant sends a message

  • message.system - Triggered when a system message is used

  • message_stream.assistant - Triggered when the assistant starts streaming a message

You can use wildcards in event patterns:

  • message.* - Matches all message events

  • *.assistant - Matches all assistant events

Event Matching Rules#

The callback system uses pattern matching to determine which callbacks to trigger for a given event:

  1. Exact matches: “message.user” matches only user message events

  2. Wildcard matches: “message.*” matches all message events

  3. Multiple wildcards: “*.assistant” matches all assistant events

Internally, the system parses the event string into a tuple of components and compares them with the callback’s pattern:

# Examples of event parsing
"message.user" -> ("message", "user")
"message.*" -> ("message", "*")
"*.assistant" -> ("*", "assistant")

Using Callbacks with Agents#

To use callbacks with an agent, pass them to the agent’s constructor:

from brain.agents.agent import Agent

# Create the agent with callbacks
agent = Agent(
    llm=llm,
    tools=[],
    callbacks=[
        log_user_message,
        log_assistant_message
    ]
)

You can register multiple callbacks for the same event; every matching callback is executed for that event. Because the agent dispatches them concurrently (see Callback Implementation Internals), avoid relying on a strict execution order between callbacks.

Synchronous vs. Asynchronous Callbacks#

Callbacks can be defined as either synchronous or asynchronous functions:

import asyncio

# Synchronous callback
@callback("message.user")
def sync_callback(agent, event, data):
    print(f"Sync: {data.content}")

# Asynchronous callback
@callback("message.user")
async def async_callback(agent, event, data):
    await asyncio.sleep(0.1)  # Simulate async operation
    print(f"Async: {data.content}")

The Malevich Brain agent will properly handle both types, automatically awaiting async callbacks when needed. However, for best performance with I/O-bound operations (like database access or network calls), asynchronous callbacks are recommended.
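A minimal sketch of how this mixed dispatch can work (illustrative only; the library's actual internals may differ):

```python
import asyncio
import inspect

async def dispatch(cb, agent, event, data):
    """Invoke a callback, awaiting the result only when it is awaitable."""
    result = cb(agent, event, data)
    if inspect.isawaitable(result):
        await result
```

A plain function returns a non-awaitable result and runs to completion immediately; a coroutine function returns a coroutine object, which `dispatch` awaits.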

Streaming Content with Callbacks#

The message_stream.assistant event provides streaming content:

@callback("message_stream.assistant")
async def stream_handler(agent, event, stream):
    """Handle streaming content"""
    print("Assistant: ", end="", flush=True)
    async for chunk in stream:
        if hasattr(chunk, "chunk"):
            # Print each chunk as it arrives
            print(chunk.chunk, end="", flush=True)
        elif hasattr(chunk, "content") and chunk.content:
            # Handle complete messages
            print(f"\nComplete message: {chunk.content}")

The stream parameter in this callback is a CallbackStream object that yields either:

  1. MessageChunk objects: Small pieces of the message being generated

  2. Message objects: Complete messages
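Because the stream mixes the two types, a small helper can normalize them to text. The dataclasses below are simplified stand-ins for the library's own MessageChunk and Message types:

```python
from dataclasses import dataclass

@dataclass
class MessageChunk:   # stand-in for the library's chunk type
    chunk: str

@dataclass
class Message:        # stand-in for the library's message type
    role: str
    content: str

def render(item) -> str:
    """Return the displayable text for either stream item type."""
    if isinstance(item, MessageChunk):
        return item.chunk
    if isinstance(item, Message):
        return item.content
    raise TypeError(f"Unexpected stream item: {type(item).__name__}")
```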

Working with the CallbackStream#

The CallbackStream is an async iterator that provides content as it’s generated. It allows real-time processing of LLM outputs.

@callback("message_stream.assistant")
async def detailed_stream_handler(agent, event, stream):
    """Detailed handling of streaming content"""
    buffer = ""

    # Iterate through the stream
    async for item in stream:
        # Different types in the stream
        if hasattr(item, "chunk"):
            # This is a MessageChunk
            chunk_text = item.chunk
            buffer += chunk_text

            # Process each chunk (e.g., print to console)
            print(chunk_text, end="", flush=True)

            # Example: Analyze content as it streams
            if "error" in chunk_text.lower():
                print("\n[Alert: Error mentioned!]")

        elif hasattr(item, "content") and hasattr(item, "role"):
            # This is a complete Message
            print(f"\n[Complete message received: {len(item.content)} chars]")

    # Stream is complete
    print(f"\n[Stream finished, total length: {len(buffer)}]")

The stream terminates automatically once the entire response has been generated; the agent signals completion internally by calling the stream’s .done() method.
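Conceptually, a CallbackStream behaves like an asyncio queue wrapped in an async iterator. The simplified sketch below (not the library's actual implementation) shows the idea: a producer put()s items, and done() ends iteration for consumers:

```python
import asyncio

_DONE = object()  # sentinel marking the end of the stream

class SimpleCallbackStream:
    """Queue-backed async iterator; .done() terminates consumers."""

    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    async def put(self, item) -> None:
        await self._queue.put(item)

    async def done(self) -> None:
        await self._queue.put(_DONE)

    def __aiter__(self):
        return self

    async def __anext__(self):
        item = await self._queue.get()
        if item is _DONE:
            raise StopAsyncIteration
        return item
```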

Practical Use Cases#

  1. Logging to a Database:

    import asyncpg
    from datetime import datetime
    
    # Shared connection pool; create it once at application startup, on the
    # same event loop the agent will run on (an asyncpg pool is bound to the
    # loop it was created in)
    db_pool = None
    
    async def setup_db_pool():
        global db_pool
        db_pool = await asyncpg.create_pool(
            "postgresql://user:password@localhost/dbname"
        )
    
    @callback("message.*")
    async def db_logger(agent, event, data):
        """Log all messages to a database"""
        async with db_pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO message_logs(role, content, event_type, timestamp)
                VALUES($1, $2, $3, $4)
                """,
                data.role,
                data.content,
                event,
                datetime.now()
            )
    
  2. Real-time UI Updates:

    import websockets
    
    # Active WebSocket connections
    active_connections = set()
    
    @callback("message_stream.assistant")
    async def update_ui(agent, event, stream):
        """Send streaming updates to a UI via WebSocket"""
    
        # Send the start message to all connections
        message = {"type": "stream_start"}
        await broadcast(message)
    
        # Process each chunk
        async for chunk in stream:
            if hasattr(chunk, "chunk"):
                # Send each chunk to connected clients
                message = {
                    "type": "stream_chunk",
                    "content": chunk.chunk
                }
                await broadcast(message)
    
        # Send the end message
        message = {"type": "stream_end"}
        await broadcast(message)
    
    async def broadcast(message):
        """Send a message to all connected WebSocket clients"""
        if active_connections:
            await asyncio.gather(*[
                connection.send_json(message)
                for connection in active_connections
            ])
    
  3. Monitoring Tool Usage:

    import time
    
    # Track tool usage metrics
    tool_metrics = {
        "calls": {},
        "errors": {},
        "execution_time": {}
    }
    
    @callback("message.user")
    async def monitor_tool_calls(agent, event, data):
        """Monitor tool usage"""
        # Only process tool response messages
        if hasattr(data, "tool_call_id") and hasattr(data, "error_message"):
            # Assumes content is formatted like "Tool <name> ..."; adapt the
            # parsing to your actual message format
            tool_name = data.content.split(" ")[1]
    
            # Initialize counters if needed
            if tool_name not in tool_metrics["calls"]:
                tool_metrics["calls"][tool_name] = 0
                tool_metrics["errors"][tool_name] = 0
                tool_metrics["execution_time"][tool_name] = []
    
            # Update metrics
            tool_metrics["calls"][tool_name] += 1
    
            # Track errors
            if data.error_message:
                tool_metrics["errors"][tool_name] += 1
    
    # Report metrics periodically
    @callback("message.assistant")
    async def log_metrics_periodically(agent, event, data):
        # Log metrics after every 10th recorded tool call
        total_calls = sum(tool_metrics["calls"].values())
        if total_calls > 0 and total_calls % 10 == 0:
            print("\n--- Tool Usage Metrics ---")
            for tool_name, call_count in tool_metrics["calls"].items():
                error_rate = tool_metrics["errors"][tool_name] / call_count if call_count > 0 else 0
                print(f"{tool_name}: {call_count} calls, {error_rate:.2%} error rate")
            print("-------------------------\n")
    
  4. Token Counting and Cost Tracking:

    import tiktoken
    
    class TokenCounter:
        def __init__(self, model_name="gpt-4o"):
            self.model = model_name
            self.encoder = tiktoken.encoding_for_model(model_name)
            self.total_prompt_tokens = 0
            self.total_completion_tokens = 0
    
        def count_tokens(self, text):
            return len(self.encoder.encode(text))
    
    # Initialize counter
    token_counter = TokenCounter()
    
    @callback("message.user")
    async def count_prompt_tokens(agent, event, data):
        """Count tokens in user messages"""
        tokens = token_counter.count_tokens(data.content)
        token_counter.total_prompt_tokens += tokens
    
    @callback("message.assistant")
    async def count_completion_tokens(agent, event, data):
        """Count tokens in assistant responses"""
        tokens = token_counter.count_tokens(data.content)
        token_counter.total_completion_tokens += tokens
    
        # Report totals roughly every 1,000 completion tokens: true whenever
        # this message pushes the running total past a multiple of 1,000
        if token_counter.total_completion_tokens % 1000 < tokens:
            # Estimate cost (based on approximate OpenAI pricing)
            prompt_cost = token_counter.total_prompt_tokens * 0.00001  # $0.01 per 1K tokens
            completion_cost = token_counter.total_completion_tokens * 0.00003  # $0.03 per 1K tokens
            total_cost = prompt_cost + completion_cost
    
            print(f"\nToken usage: {token_counter.total_prompt_tokens} prompt, " +
                  f"{token_counter.total_completion_tokens} completion")
            print(f"Estimated cost: ${total_cost:.4f}")
    

Advanced Callback Patterns#

  1. Conversation History Tracking:

    import json
    from datetime import datetime
    
    # Track conversation state
    conversation_history = []
    
    @callback("message.*")
    async def track_conversation(agent, event, data):
        """Add messages to conversation history"""
        conversation_history.append({
            "role": data.role,
            "content": data.content,
            "timestamp": datetime.now().isoformat()
        })
    
        # Limit history size (keep last 100 messages)
        if len(conversation_history) > 100:
            conversation_history.pop(0)
    
    # Save conversation periodically
    @callback("message.user")
    async def save_conversation(agent, event, data):
        # Save every 10 messages
        if len(conversation_history) % 10 == 0:
            with open(f"conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", "w") as f:
                json.dump(conversation_history, f, indent=2)
    
  2. Tool Call Lifecycle Monitoring:

    import time
    
    # Track active tool calls
    active_tool_calls = {}
    
    @callback("message.assistant")
    async def track_tool_call_start(agent, event, data):
        """Track when a tool call is initiated"""
        if hasattr(data, "tool_call_id") and hasattr(data, "tool_name"):
            active_tool_calls[data.tool_call_id] = {
                "tool_name": data.tool_name,
                "start_time": time.time(),
                "status": "started"
            }
    
    @callback("message.user")
    async def track_tool_call_end(agent, event, data):
        """Track when a tool call is completed"""
        if hasattr(data, "tool_call_id"):
            call_id = data.tool_call_id
            if call_id in active_tool_calls:
                active_tool_calls[call_id]["end_time"] = time.time()
                active_tool_calls[call_id]["status"] = "completed"
    
                # Calculate duration
                duration = active_tool_calls[call_id]["end_time"] - active_tool_calls[call_id]["start_time"]
                active_tool_calls[call_id]["duration"] = duration
    
                if hasattr(data, "error_message") and data.error_message:
                    active_tool_calls[call_id]["status"] = "error"
                    active_tool_calls[call_id]["error"] = data.error_message
    
                # Log completion
                tool_name = active_tool_calls[call_id]["tool_name"]
                status = active_tool_calls[call_id]["status"]
                print(f"Tool {tool_name} {status} in {duration:.4f}s")
    
  3. Debug Logging Decorator:

    def with_debug_logging(event_pattern):
        """Decorator to add debug logging to a callback"""
        def decorator(callback_func):
            # Create the callback as normal
            cb = callback(event_pattern)(callback_func)
    
            # Create a wrapper around the original route method
            original_route = cb.route
    
            async def route_with_logging(agent, event, data):
                print(f"DEBUG: Callback triggered for event {event}")
                try:
                    # Call the original route method
                    result = await original_route(agent, event, data)
                    print(f"DEBUG: Callback completed successfully")
                    return result
                except Exception as e:
                    print(f"DEBUG: Callback failed with error: {e}")
                    raise
    
            # Replace the route method
            cb.route = route_with_logging
            return cb
        return decorator
    
    # Usage
    @with_debug_logging("message.user")
    async def my_callback(agent, event, data):
        print(f"Processing message: {data.content}")
    
  4. Chaining Callbacks:

    import asyncio

    class CallbackChain:
        """Chain multiple callbacks together"""
    
        def __init__(self, event_pattern, callbacks):
            self.event_pattern = event_pattern
            self.callbacks = callbacks
    
        def create(self):
            """Create a single callback that runs all callbacks in the chain"""
            @callback(self.event_pattern)
            async def chained_callback(agent, event, data):
                for cb_func in self.callbacks:
                    result = cb_func(agent, event, data)
                    # Await the result if the callback was asynchronous
                    if asyncio.iscoroutine(result):
                        await result
    
            return chained_callback
    
    # Usage
    def log_to_console(agent, event, data):
        print(f"CONSOLE: {data.content}")
    
    async def log_to_database(agent, event, data):
        print(f"DATABASE: Saving {data.content}")
        await asyncio.sleep(0.1)  # Simulate database operation
    
    # Create a chain
    message_chain = CallbackChain("message.user", [
        log_to_console,
        log_to_database
    ]).create()
    
    # Use in an agent
    agent = Agent(llm=llm, tools=[], callbacks=[message_chain])
    

Best Practices#

  1. Performance Considerations:

    • Keep callbacks lightweight and fast to avoid slowing down the agent

    • Use asynchronous callbacks for I/O operations

    • Avoid blocking operations in synchronous callbacks

    • Be mindful of memory usage in callbacks that store data
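When a callback must call a blocking, synchronous API, one way to honor the points above is to push the blocking work onto a worker thread with asyncio.to_thread. This is a sketch; blocking_write and log_message are hypothetical names:

```python
import asyncio

def blocking_write(path: str, text: str) -> None:
    """Synchronous file I/O that would stall the loop if called directly."""
    with open(path, "a") as f:
        f.write(text + "\n")

async def log_message(path: str, content: str) -> None:
    # Run the blocking call on a worker thread so the agent's event loop
    # (and any other callbacks) keep running in the meantime
    await asyncio.to_thread(blocking_write, path, content)
```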

  2. Error Handling:

    • Implement proper error handling in callbacks to prevent them from crashing the agent

    • Use try/except blocks to catch and log errors

    @callback("message.user")
    async def safe_callback(agent, event, data):
        try:
            # Potentially risky operation
            result = await external_api_call(data.content)
            print(f"API result: {result}")
        except Exception as e:
            print(f"Callback error (non-fatal): {e}")
            # Continue execution without disrupting the agent
    
  3. Resource Management:

    • Initialize expensive resources (database connections, API clients) outside callbacks

    • Use connection pooling for database connections

    • Close resources properly when they’re no longer needed
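One way to follow these points is to own the resource in an async context manager that lives for the agent's whole run. The sketch below uses a stand-in pool object; a real pool (e.g. asyncpg's) would be created and closed at the same two points:

```python
from contextlib import asynccontextmanager

@asynccontextmanager
async def shared_pool():
    """Create the expensive resource once; guarantee cleanup on shutdown."""
    pool = ["conn-1", "conn-2"]   # stand-in for e.g. an asyncpg pool
    try:
        yield pool                # callbacks close over this object
    finally:
        pool.clear()              # stand-in for `await pool.close()`
```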

  4. Debugging Callbacks:

    • Add logging to track callback execution

    • Use unique identifiers for messages and tool calls to trace execution flow

    • Consider using the debug logging decorator pattern shown above

  5. Organization:

    • Group related callbacks together in modules

    • Use descriptive names for callback functions

    • Document the purpose and behavior of each callback

    # In callbacks/logging.py
    @callback("message.*")
    async def log_all_messages(agent, event, data):
        """Log all messages to console with timestamp"""
    
    # In callbacks/metrics.py
    @callback("message.user")
    async def track_tool_usage(agent, event, data):
        """Track and record tool usage metrics"""
    
    # In main.py
    from callbacks.logging import log_all_messages
    from callbacks.metrics import track_tool_usage
    
    agent = Agent(
        llm=llm,
        tools=tools,
        callbacks=[log_all_messages, track_tool_usage]
    )
    

Callback Implementation Internals#

Understanding how callbacks are implemented internally can help you use them more effectively:

  1. Callback Registration:

    When you create an agent with callbacks, they are stored in the agent’s callbacks list:

    # In Agent.__init__
    self.callbacks = callbacks
    
  2. Event Triggering:

    The agent triggers events at specific points using the execute_callback method:

    # In Agent.run
    await self.execute_callback('message.user', message)
    # ... later ...
    await self.execute_callback('message.assistant', message)
    
  3. Callback Execution:

    The execute_callback method routes events to matching callbacks:

    async def execute_callback(self, event: str, data: Any) -> None:
        await asyncio.gather(*[
            callback.route(self, event, data)
            for callback in self.callbacks
        ])
    

    This gathers all the callback coroutines and runs them concurrently with asyncio.gather.

  4. Streaming Implementation:

    For streaming events, a special method execute_stream_callback creates CallbackStream objects:

    async def execute_stream_callback(self, event: str) -> list[tuple[CallbackStream, asyncio.Task[None]]]:
        cbs = [cb for cb in self.callbacks if Callback.match_events(cb.event_type, Callback.parse(event))]
        streams = [CallbackStream() for cb in cbs]
        tasks = [asyncio.create_task(cb.route(self, event, stream))
                 for cb, stream in zip(cbs, streams)]
        return list(zip(streams, tasks))