Chat Memory with Callbacks#

This example demonstrates how to create an agent with memory by using callbacks to track conversation history and retrieve past messages.

Building a Memory System#

First, let’s create a memory system to store and retrieve conversation history:

from brain.agents.tool import tool, Output
from brain.agents.callback import callback
from pydantic import BaseModel
from typing import List, Optional
import json
from datetime import datetime
import os

# Define models for our memory system
class MemoryEntry(BaseModel):
    """A single stored conversation message."""
    role: str  # speaker role, e.g. "user" or "assistant"
    content: str  # the message text
    timestamp: str  # ISO-8601 string set by ConversationMemory.add_entry

class MemorySearchResult(Output):
    """Tool output carrying memory entries matched by a search/retrieval."""
    entries: List[MemoryEntry]  # matched entries (tools populate most-recent-first)
    total_found: int  # number of entries returned, i.e. len(entries)

# Create a simple file-based memory store
class ConversationMemory:
    """File-backed store of conversation messages.

    Entries are kept in chronological order in memory and persisted to a
    JSON file after every addition, so memory survives process restarts.
    """

    def __init__(self, memory_file: str = "conversation_memory.json"):
        self.memory_file = memory_file
        self.entries: List[MemoryEntry] = []
        self._load_memory()

    def _load_memory(self):
        """Load memory from file if it exists; a corrupt file is logged, not fatal."""
        if os.path.exists(self.memory_file):
            try:
                with open(self.memory_file, "r") as f:
                    data = json.load(f)
                    self.entries = [MemoryEntry(**entry) for entry in data]
            except Exception as e:
                # Best-effort load: start with empty memory rather than crash
                # on a missing or corrupt store.
                print(f"Error loading memory: {e}")

    def _save_memory(self):
        """Persist all entries to the JSON memory file."""
        with open(self.memory_file, "w") as f:
            # model_dump() is the pydantic v2 replacement for the
            # deprecated BaseModel.dict().
            json.dump([entry.model_dump() for entry in self.entries], f, indent=2)

    def add_entry(self, role: str, content: str):
        """Append a message to memory and persist immediately."""
        entry = MemoryEntry(
            role=role,
            content=content,
            timestamp=datetime.now().isoformat(),
        )
        self.entries.append(entry)
        self._save_memory()

    def search(self, query: str, limit: int = 5) -> List[MemoryEntry]:
        """Return up to `limit` entries whose content contains `query`
        (case-insensitive), most recent first."""
        needle = query.lower()  # hoisted loop invariant
        results: List[MemoryEntry] = []
        for entry in reversed(self.entries):
            if needle in entry.content.lower():
                results.append(entry)
                if len(results) >= limit:
                    break
        return results

    def get_recent(self, count: int = 5) -> List[MemoryEntry]:
        """Return the last `count` entries, most recent first."""
        if count <= 0:
            # Guard: entries[-0:] would slice the ENTIRE list, not none of it.
            return []
        return list(reversed(self.entries[-count:]))

# Module-level singleton so the tools and callbacks below all share
# one memory store (and one backing JSON file).
memory = ConversationMemory()

Setting Up Memory Callbacks#

Next, let’s create a callback that automatically records each message in the conversation:

# Create callbacks to track all messages
@callback("message.*")
async def track_conversation_history(agent, event, data):
    """Persist every non-tool message into the shared conversation memory."""
    # Tool-response messages carry a tool_call_id attribute; those are
    # framework plumbing and are not stored as conversation history.
    is_tool_response = hasattr(data, "tool_call_id")
    if not is_tool_response:
        memory.add_entry(data.role, data.content)

Creating the Agent with Memory#

Now, let’s create an agent that uses our memory system:

import asyncio
import os
from brain.agents.agent import Agent
from brain.agents.llm.openai import OpenAIBaseLLM
from brain.agents.callback import callback

@callback("message_stream.assistant")
async def stream_to_console(agent, event, stream):
    """Print assistant output to stdout incrementally as chunks arrive."""
    print("\nAssistant: ", end="", flush=True)
    async for piece in stream:
        # Only objects exposing a `chunk` attribute carry printable text.
        if hasattr(piece, "chunk"):
            print(piece.chunk, end="", flush=True)

async def main():
    """Run an interactive REPL backed by a memory-aware agent."""
    # Initialize the LLM
    llm = OpenAIBaseLLM(
        api_key=os.environ.get("OPENAI_API_KEY"),
        default_model="gpt-4o"
    )

    # Create the agent with memory tools and the memory callback
    agent = Agent(
        llm=llm,
        tools=[search_memory, get_recent_messages],
        instructions="""
        You are a helpful assistant with memory. You can remember previous parts of the
        conversation and reference them when appropriate.

        You have access to the following memory tools:
        - search_memory: Search for specific information in past conversations
        - get_recent_messages: Get the most recent messages from the conversation

        Use these tools when the user refers to something mentioned earlier or when you
        need to check past information to provide a consistent response.
        """,
        callbacks=[
            track_conversation_history,
            stream_to_console
        ]
    )

    # Run a conversation loop
    print("Memory Assistant (type 'exit' to quit)")
    while True:
        try:
            user_input = input("\nYou: ")
        except EOFError:
            # Ctrl-D / end of piped input: exit cleanly instead of crashing.
            break
        if user_input.lower() == "exit":
            break

        # Process with the agent; output arrives via the streaming callback.
        await agent.run(user_input)
        print()  # Add a newline after the response

if __name__ == "__main__":
    asyncio.run(main())

Complete Example#

Here’s the complete example with all components integrated:

import asyncio
import json
import os
from datetime import datetime
from typing import List, Optional

from brain.agents.agent import Agent
from brain.agents.callback import callback
from brain.agents.llm.openai import OpenAIBaseLLM
from brain.agents.tool import tool, Output
from pydantic import BaseModel

# Define models for our memory system
class MemoryEntry(BaseModel):
    """A single stored conversation message."""
    role: str  # speaker role, e.g. "user" or "assistant"
    content: str  # the message text
    timestamp: str  # ISO-8601 string set by ConversationMemory.add_entry

class MemorySearchResult(Output):
    """Tool output carrying memory entries matched by a search/retrieval."""
    entries: List[MemoryEntry]  # matched entries (tools populate most-recent-first)
    total_found: int  # number of entries returned, i.e. len(entries)

# Create a simple file-based memory store
class ConversationMemory:
    """File-backed store of conversation messages.

    Entries are kept in chronological order in memory and persisted to a
    JSON file after every addition, so memory survives process restarts.
    """

    def __init__(self, memory_file: str = "conversation_memory.json"):
        self.memory_file = memory_file
        self.entries: List[MemoryEntry] = []
        self._load_memory()

    def _load_memory(self):
        """Load memory from file if it exists; a corrupt file is logged, not fatal."""
        if os.path.exists(self.memory_file):
            try:
                with open(self.memory_file, "r") as f:
                    data = json.load(f)
                    self.entries = [MemoryEntry(**entry) for entry in data]
            except Exception as e:
                # Best-effort load: start with empty memory rather than crash
                # on a missing or corrupt store.
                print(f"Error loading memory: {e}")

    def _save_memory(self):
        """Persist all entries to the JSON memory file."""
        with open(self.memory_file, "w") as f:
            json.dump([entry.model_dump() for entry in self.entries], f, indent=2)

    def add_entry(self, role: str, content: str):
        """Append a message to memory and persist immediately."""
        entry = MemoryEntry(
            role=role,
            content=content,
            timestamp=datetime.now().isoformat(),
        )
        self.entries.append(entry)
        self._save_memory()

    def search(self, query: str, limit: int = 5) -> List[MemoryEntry]:
        """Return up to `limit` entries whose content contains `query`
        (case-insensitive), most recent first."""
        needle = query.lower()  # hoisted loop invariant
        results: List[MemoryEntry] = []
        for entry in reversed(self.entries):
            if needle in entry.content.lower():
                results.append(entry)
                if len(results) >= limit:
                    break
        return results

    def get_recent(self, count: int = 5) -> List[MemoryEntry]:
        """Return the last `count` entries, most recent first.

        Note: the empty-list guard is unnecessary for slicing (an empty
        list slices to []), but count <= 0 needs one: entries[-0:]
        would return the ENTIRE list.
        """
        if count <= 0:
            return []
        return list(reversed(self.entries[-count:]))

# Module-level singleton so the tools and callbacks below all share
# one memory store (and one backing JSON file).
memory = ConversationMemory()

# Define the memory query model
class MemoryQuery(BaseModel):
    """Input schema for the search_memory tool."""
    query: str  # substring to look for (matched case-insensitively by memory.search)
    limit: int = 5  # maximum number of entries to return

# Create memory-related tools
@tool()
def search_memory(input: MemoryQuery) -> MemorySearchResult:
    """
    Search previous conversation history for relevant information.

    Args:
        input: A MemoryQuery holding the search query and an optional
            result limit.

    Returns:
        MemorySearchResult with the matching entries and their count.
    """
    matches = memory.search(input.query, input.limit)
    return MemorySearchResult(entries=matches, total_found=len(matches))

@tool()
def get_recent_messages(limit: int = 5) -> MemorySearchResult:
    """
    Retrieve the most recent messages from the conversation.

    Args:
        limit: Maximum number of messages to retrieve.

    Returns:
        MemorySearchResult with the most recent messages and their count.
    """
    recent = memory.get_recent(limit)
    return MemorySearchResult(entries=recent, total_found=len(recent))

# Create callbacks for tracking conversation and streaming
@callback("message.*")
async def track_conversation_history(agent, event, data):
    """Persist every non-tool message into the shared conversation memory."""
    # Tool-response messages carry a tool_call_id attribute; those are
    # framework plumbing and are not stored as conversation history.
    is_tool_response = hasattr(data, "tool_call_id")
    if not is_tool_response:
        memory.add_entry(data.role, data.content)

@callback("message_stream.assistant")
async def stream_to_console(agent, event, stream):
    """Print assistant output to stdout incrementally as chunks arrive."""
    print("\nAssistant: ", end="", flush=True)
    async for piece in stream:
        # Only objects exposing a `chunk` attribute carry printable text.
        if hasattr(piece, "chunk"):
            print(piece.chunk, end="", flush=True)

async def main():
    """Run an interactive REPL backed by a memory-aware agent."""
    # Initialize the LLM
    llm = OpenAIBaseLLM(
        api_key=os.environ.get("OPENAI_API_KEY"),
        default_model="gpt-4o"
    )

    # Create the agent with memory tools and the memory callback
    agent = Agent(
        llm=llm,
        tools=[search_memory, get_recent_messages],
        instructions="""
        You are a helpful assistant with memory. You can remember previous parts of the
        conversation and reference them when appropriate.

        You have access to the following memory tools:
        - search_memory: Search for specific information in past conversations
        - get_recent_messages: Get the most recent messages from the conversation

        Use these tools when the user refers to something mentioned earlier or when you
        need to check past information to provide a consistent response.
        """,
        callbacks=[
            track_conversation_history,
            stream_to_console
        ]
    )

    # Run a conversation loop
    print("Memory Assistant (type 'exit' to quit)")
    print("This assistant remembers your conversation history and can refer back to it.")
    print("Try asking about something mentioned earlier in the conversation!")

    while True:
        try:
            user_input = input("\nYou: ")
        except EOFError:
            # Ctrl-D / end of piped input: exit cleanly instead of crashing.
            break
        if user_input.lower() == "exit":
            break

        # Process with the agent; output arrives via the streaming callback.
        await agent.run(user_input)
        print()  # Add a newline after the response

if __name__ == "__main__":
    asyncio.run(main())

Alternative Approach: Stateful Tool#

Instead of using callbacks, you can also implement memory using a stateful tool:

from brain.agents.tool import tool, StateContext

# Define a state schema
# Define a state schema
class ChatState(BaseModel):
    """Per-conversation state for the stateful-tool approach."""
    history: List[MemoryEntry] = []  # stored messages; searched newest-first below

@tool(stateful=True, state_schema=ChatState)
def remember_and_search(query: str, context: StateContext[ChatState], limit: int = 3) -> str:
    """
    Search conversation history for messages containing the query.

    Args:
        query: The search query (matched case-insensitively)
        context: The state context containing conversation history
        limit: Maximum number of matches to return (default 3; previously
            a hard-coded constant)

    Returns:
        The search results as a formatted string
    """
    # NOTE(review): despite the name, this tool only searches; it never
    # writes the query back into state.history — confirm intended behavior.
    state = context.state
    needle = query.lower()  # hoisted loop invariant

    # Scan newest-first so the most recent matches are reported.
    results = []
    for entry in reversed(state.history):
        if needle in entry.content.lower():
            results.append(entry)
            if len(results) >= limit:
                break

    if not results:
        return "I couldn't find any relevant information in our conversation history."

    result_text = "I found these relevant messages:\n\n"
    for i, entry in enumerate(results, start=1):
        result_text += f"{i}. {entry.role}: {entry.content}\n"
    return result_text

This approach has the advantage of using the built-in state management feature of the Tool system, but it requires more manual memory management compared to the callback approach.