"""Chat-history memory component for agents (module atomic_agents.lib.components.agent_memory)."""

import json
import uuid
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Type

from instructor.multimodal import PDF, Image, Audio
from pydantic import BaseModel, Field

from atomic_agents.lib.base.base_io_schema import BaseIOSchema


# Multimodal content types from instructor. AgentMemory.get_history passes
# instances of these types through unchanged instead of serializing them to JSON.
INSTRUCTOR_MULTIMODAL_TYPES = (Image, Audio, PDF)


class Message(BaseModel):
    """
    A single entry in an agent's chat history.

    Attributes:
        role (str): Who produced the message (for example 'user', 'system' or 'tool').
        content (BaseIOSchema): The schema instance carried by the message.
        turn_id (Optional[str]): Identifier of the turn this message is part of, if any.
    """

    # Sender role of the message.
    role: str
    # Payload of the message, typed as a BaseIOSchema.
    content: BaseIOSchema
    # Turn identifier; None when the message is not tied to a turn.
    turn_id: Optional[str] = None
class AgentMemory:
    """
    Manages the chat history for an AI agent.

    Attributes:
        history (List[Message]): A list of messages representing the chat history.
        max_messages (Optional[int]): Maximum number of messages to keep in history.
        current_turn_id (Optional[str]): The ID of the current turn.
    """

    def __init__(self, max_messages: Optional[int] = None):
        """
        Initializes the AgentMemory with an empty history and optional constraints.

        Args:
            max_messages (Optional[int]): Maximum number of messages to keep in history.
                When exceeded, oldest messages are removed first.
        """
        self.history: List[Message] = []
        self.max_messages = max_messages
        self.current_turn_id: Optional[str] = None

    def initialize_turn(self) -> None:
        """
        Initializes a new turn by generating a random turn ID.
        """
        self.current_turn_id = str(uuid.uuid4())

    def add_message(
        self,
        role: str,
        content: "BaseIOSchema",
    ) -> None:
        """
        Adds a message to the chat history and manages overflow.

        If no turn has been started yet, a new turn is initialized first so the
        message always carries a turn ID.

        Args:
            role (str): The role of the message sender.
            content (BaseIOSchema): The content of the message.
        """
        if self.current_turn_id is None:
            self.initialize_turn()

        message = Message(
            role=role,
            content=content,
            turn_id=self.current_turn_id,
        )
        self.history.append(message)
        self._manage_overflow()

    def _manage_overflow(self) -> None:
        """
        Trims the oldest messages so the history respects ``max_messages``.

        Does nothing when ``max_messages`` is None.
        """
        if self.max_messages is None:
            return
        excess = len(self.history) - self.max_messages
        if excess > 0:
            # One slice deletion instead of repeated pop(0), which shifts the
            # whole list on every call.
            del self.history[:excess]

    def get_history(self) -> List[Dict]:
        """
        Retrieves the chat history, handling both regular and multimodal content.

        Returns:
            List[Dict]: The list of messages in the chat history as dictionaries.
            Each dictionary has 'role' and 'content' keys, where 'content' is a list
            that may contain strings (JSON) or multimodal objects.

        Note:
            - Each top-level field becomes its own entry: a multimodal value is passed
              through as-is, anything else is serialized as a JSON string holding
              only that field.
            - For a list field mixing multimodal and plain items, the multimodal items
              are passed through and the remaining items are serialized together as
              one JSON string after them.
            - This method does not support nested multimodal content. If your schema
              contains nested objects that themselves contain multimodal content, only
              the top-level multimodal content will be properly processed.
        """
        history = []
        for message in self.history:
            input_content = message.content
            processed_content = []
            for field_name in input_content.__class__.model_fields:
                field_value = getattr(input_content, field_name)
                if isinstance(field_value, list):
                    multimodal_items = [
                        item for item in field_value if isinstance(item, INSTRUCTOR_MULTIMODAL_TYPES)
                    ]
                    if not multimodal_items:
                        processed_content.append(input_content.model_dump_json(include={field_name}))
                        continue
                    processed_content.extend(multimodal_items)
                    # Keep the plain items of a mixed list instead of silently dropping them.
                    plain_indices = {
                        index: True
                        for index, item in enumerate(field_value)
                        if not isinstance(item, INSTRUCTOR_MULTIMODAL_TYPES)
                    }
                    if plain_indices:
                        processed_content.append(
                            input_content.model_dump_json(include={field_name: plain_indices})
                        )
                elif isinstance(field_value, INSTRUCTOR_MULTIMODAL_TYPES):
                    processed_content.append(field_value)
                else:
                    processed_content.append(input_content.model_dump_json(include={field_name}))
            history.append({"role": message.role, "content": processed_content})

        return history

    def copy(self) -> "AgentMemory":
        """
        Creates a copy of the chat memory.

        The copy is produced by round-tripping through ``dump``/``load``, so message
        contents are rebuilt as new schema instances.

        Returns:
            AgentMemory: A copy of the chat memory.
        """
        new_memory = AgentMemory(max_messages=self.max_messages)
        new_memory.load(self.dump())
        new_memory.current_turn_id = self.current_turn_id
        return new_memory

    def get_current_turn_id(self) -> Optional[str]:
        """
        Returns the current turn ID.

        Returns:
            Optional[str]: The current turn ID, or None if not set.
        """
        return self.current_turn_id

    def delete_turn_id(self, turn_id: str) -> None:
        """
        Deletes every message that belongs to the given turn.

        Afterwards, if the history is empty, ``current_turn_id`` is reset to None;
        otherwise, if the deleted turn was the current one, ``current_turn_id``
        moves to the turn of the last remaining message.

        Args:
            turn_id (str): The turn ID whose messages should be deleted.

        Raises:
            ValueError: If no message with the specified turn ID is in the memory.
        """
        initial_length = len(self.history)
        self.history = [msg for msg in self.history if msg.turn_id != turn_id]

        if len(self.history) == initial_length:
            raise ValueError(f"Turn ID {turn_id} not found in memory.")

        # Keep current_turn_id pointing at a turn that still exists.
        if not self.history:
            self.current_turn_id = None
        elif turn_id == self.current_turn_id:
            self.current_turn_id = self.history[-1].turn_id

    def get_message_count(self) -> int:
        """
        Returns the number of messages in the chat history.

        Returns:
            int: The number of messages.
        """
        return len(self.history)

    def dump(self) -> str:
        """
        Serializes the entire AgentMemory instance to a JSON string.

        Each message content is stored with its fully qualified class name so that
        ``load`` can re-create the right schema class.

        Returns:
            str: A JSON string representation of the AgentMemory.
        """
        serialized_history = []
        for message in self.history:
            content_class = message.content.__class__
            serialized_message = {
                "role": message.role,
                "content": {
                    "class_name": f"{content_class.__module__}.{content_class.__name__}",
                    "data": message.content.model_dump_json(),
                },
                "turn_id": message.turn_id,
            }
            serialized_history.append(serialized_message)

        memory_data = {
            "history": serialized_history,
            "max_messages": self.max_messages,
            "current_turn_id": self.current_turn_id,
        }
        return json.dumps(memory_data)

    def load(self, serialized_data: str) -> None:
        """
        Deserializes a JSON string and loads it into the AgentMemory instance.

        The instance is only modified once the whole payload has been parsed, so a
        failure leaves the existing history and settings untouched.

        Security: each message's ``class_name`` is imported by module path, so only
        load data from trusted sources.

        Args:
            serialized_data (str): A JSON string representation of the AgentMemory.

        Raises:
            ValueError: If the serialized data is invalid or cannot be deserialized.
        """
        try:
            memory_data = json.loads(serialized_data)
            max_messages = memory_data["max_messages"]
            current_turn_id = memory_data["current_turn_id"]

            history = []
            for message_data in memory_data["history"]:
                content_info = message_data["content"]
                content_class = self._get_class_from_string(content_info["class_name"])
                content_instance = content_class.model_validate_json(content_info["data"])

                # Process any Image/PDF objects to convert string paths back to Path objects
                self._process_multimodal_paths(content_instance)

                history.append(
                    Message(role=message_data["role"], content=content_instance, turn_id=message_data["turn_id"])
                )
        except (json.JSONDecodeError, KeyError, AttributeError, TypeError, ImportError) as e:
            raise ValueError(f"Invalid serialized data: {e}") from e

        # Commit only after everything was parsed successfully.
        self.history = history
        self.max_messages = max_messages
        self.current_turn_id = current_turn_id

    @staticmethod
    def _get_class_from_string(class_string: str) -> "Type[BaseIOSchema]":
        """
        Retrieves a class object from its string representation.

        Args:
            class_string (str): The fully qualified class name ("package.module.ClassName").

        Returns:
            Type[BaseIOSchema]: The class object.

        Raises:
            ImportError: If the module cannot be imported.
            AttributeError: If the class cannot be found in the module.
        """
        module_name, class_name = class_string.rsplit(".", 1)
        # NOTE: this imports arbitrary modules named in the data; never feed it untrusted input.
        module = __import__(module_name, fromlist=[class_name])
        return getattr(module, class_name)

    def _process_multimodal_paths(self, obj):
        """
        Process multimodal objects to convert string paths to Path objects.

        Note: this is necessary only for PDF and Image instructor types. The from_path
        behavior is slightly different for Audio as it keeps the source as a string.

        Args:
            obj: The object to process (recursively walks lists, dicts, pydantic
                models and plain objects; mutates Image/PDF sources in place).
        """
        if isinstance(obj, (Image, PDF)) and isinstance(obj.source, str):
            # Check if the string looks like a file path (not a URL or base64 data)
            if not obj.source.startswith(("http://", "https://", "data:")):
                obj.source = Path(obj.source)
        elif isinstance(obj, list):
            for item in obj:
                self._process_multimodal_paths(item)
        elif isinstance(obj, dict):
            for value in obj.values():
                self._process_multimodal_paths(value)
        elif hasattr(obj, "model_fields"):
            # NOTE(review): newer pydantic releases deprecate reading model_fields on an
            # instance; consider type(obj).model_fields once class objects are ruled out.
            for field_name in obj.model_fields:
                if hasattr(obj, field_name):
                    self._process_multimodal_paths(getattr(obj, field_name))
        elif hasattr(obj, "__dict__") and not isinstance(obj, Enum):
            for attr_name, attr_value in obj.__dict__.items():
                if attr_name != "__pydantic_fields_set__":  # Skip pydantic internal fields
                    self._process_multimodal_paths(attr_value)
if __name__ == "__main__":
    import os

    import instructor

    # --- Schemas exercised by this demo -------------------------------------
    class NestedSchema(BaseIOSchema):
        """A nested schema for testing"""

        nested_field: str = Field(..., description="A nested field")
        nested_int: int = Field(..., description="A nested integer")

    class ComplexInputSchema(BaseIOSchema):
        """Complex Input Schema"""

        text_field: str = Field(..., description="A text field")
        number_field: float = Field(..., description="A number field")
        list_field: List[str] = Field(..., description="A list of strings")
        nested_field: NestedSchema = Field(..., description="A nested schema")

    class ComplexOutputSchema(BaseIOSchema):
        """Complex Output Schema"""

        response_text: str = Field(..., description="A response text")
        calculated_value: int = Field(..., description="A calculated value")
        data_dict: Dict[str, NestedSchema] = Field(..., description="A dictionary of nested schemas")

    class MultimodalSchema(BaseIOSchema):
        """Schema for testing multimodal content"""

        instruction_text: str = Field(..., description="The instruction text")
        images: List[instructor.Image] = Field(..., description="The images to analyze")

    def report_loaded(memory: "AgentMemory") -> None:
        """Print every message held by *memory*, field by field."""
        print("\nLoaded memory details:")
        for index, entry in enumerate(memory.history, start=1):
            print(f"\nMessage {index}:")
            print(f"Role: {entry.role}")
            print(f"Turn ID: {entry.turn_id}")
            print(f"Content type: {type(entry.content).__name__}")
            print("Content:")
            for key, val in entry.content.model_dump().items():
                print(f" {key}: {val}")

    # --- Populate a memory with structured messages --------------------------
    original_memory = AgentMemory(max_messages=10)

    input_payload = ComplexInputSchema(
        text_field="Hello, this is a complex input",
        number_field=3.14159,
        list_field=["item1", "item2", "item3"],
        nested_field=NestedSchema(nested_field="Nested input", nested_int=42),
    )
    original_memory.add_message("user", input_payload)

    output_payload = ComplexOutputSchema(
        response_text="This is a complex response",
        calculated_value=100,
        data_dict={
            "key1": NestedSchema(nested_field="Nested output 1", nested_int=10),
            "key2": NestedSchema(nested_field="Nested output 2", nested_int=20),
        },
    )
    original_memory.add_message("assistant", output_payload)

    # Only exercise multimodal content when the sample image is present.
    image_path = os.path.join("test_images", "test.jpg")
    if os.path.exists(image_path):
        image_payload = MultimodalSchema(
            instruction_text="Please analyze this image",
            images=[instructor.Image.from_path(image_path)],
        )
        original_memory.add_message("user", image_payload)

    # --- Round-trip through dump/load ----------------------------------------
    serialized = original_memory.dump()
    print("Dumped data:")
    print(serialized)

    loaded_memory = AgentMemory()
    loaded_memory.load(serialized)
    report_loaded(loaded_memory)

    print("\nFinal verification:")
    print(f"Max messages: {loaded_memory.max_messages}")
    print(f"Current turn ID: {loaded_memory.get_current_turn_id()}")
    print("Last message content:")
    print(loaded_memory.history[-1].content.model_dump())