Source code for atomic_agents.lib.components.agent_memory

import uuid
import json
from typing import Dict, List, Optional, Type
from pydantic import BaseModel, Field
from atomic_agents.lib.base.base_io_schema import BaseIOSchema


class Message(BaseModel):
    """
    A single entry in an agent's chat history.

    Attributes:
        role (str): Who produced the message, for example 'user', 'system' or 'tool'.
        content (BaseIOSchema): The schema instance carrying the message payload.
        turn_id (Optional[str]): Identifier of the turn the message was recorded in;
            None when the message is not associated with a turn.
    """

    # Sender role of the message ('user', 'system', 'tool', ...).
    role: str
    # Structured payload; any BaseIOSchema subclass instance.
    content: BaseIOSchema
    # Turn identifier shared by all messages recorded during the same turn.
    turn_id: Optional[str] = Field(default=None)
class AgentMemory:
    """
    Manages the chat history for an AI agent.

    Attributes:
        history (List[Message]): A list of messages representing the chat history.
        max_messages (Optional[int]): Maximum number of messages to keep in history.
        current_turn_id (Optional[str]): The ID of the current turn.
    """

    def __init__(self, max_messages: Optional[int] = None):
        """
        Initializes the AgentMemory with an empty history and optional constraints.

        Args:
            max_messages (Optional[int]): Maximum number of messages to keep in history.
                When exceeded, oldest messages are removed first. None means no limit.
        """
        self.history: List[Message] = []
        self.max_messages = max_messages
        self.current_turn_id: Optional[str] = None

    def initialize_turn(self) -> None:
        """
        Initializes a new turn by generating a random turn ID.
        """
        self.current_turn_id = str(uuid.uuid4())

    def add_message(
        self,
        role: str,
        content: BaseIOSchema,
    ) -> None:
        """
        Adds a message to the chat history and manages overflow.

        If no turn has been started yet, a new turn ID is generated first.

        Args:
            role (str): The role of the message sender.
            content (BaseIOSchema): The content of the message.
        """
        if self.current_turn_id is None:
            self.initialize_turn()

        message = Message(
            role=role,
            content=content,
            turn_id=self.current_turn_id,
        )
        self.history.append(message)
        self._manage_overflow()

    def _manage_overflow(self) -> None:
        """
        Drops the oldest messages so the history holds at most max_messages entries.
        """
        if self.max_messages is None:
            return
        excess = len(self.history) - self.max_messages
        if excess > 0:
            # One slice deletion instead of repeated pop(0), each of which is O(n).
            del self.history[:excess]

    @staticmethod
    def _is_image_dict(value) -> bool:
        """Return True if a dumped field value looks like an image (dict with an image/* media_type)."""
        if not isinstance(value, dict):
            return False
        media_type = value.get("media_type")
        # Guard against non-string media_type (e.g. None), which would break .startswith().
        return isinstance(media_type, str) and media_type.startswith("image")

    @staticmethod
    def _is_image_field(value) -> bool:
        """Return True if a dumped field value is an image or a list containing at least one image."""
        if isinstance(value, list):
            return any(AgentMemory._is_image_dict(item) for item in value)
        return AgentMemory._is_image_dict(value)

    def get_history(self) -> List[Dict]:
        """
        Retrieves the chat history, handling both regular and multimodal content.

        Regular messages become ``{"role": ..., "content": <JSON string>}``. Messages
        with image fields become ``{"role": ..., "content": [<JSON string of the
        non-image fields>, *image_objects]}``, where the image objects are taken
        directly from the content instance (not their dumped dict form).

        Returns:
            List[Dict]: The list of messages in the chat history as dictionaries.
        """
        history = []
        for message in self.history:
            content = message.content
            # mode="json" so field types json.dumps cannot handle natively
            # (Path, datetime, UUID, Decimal, ...) are converted first.
            message_content = content.model_dump(mode="json")
            image_keys = [key for key, value in message_content.items() if self._is_image_field(value)]

            if image_keys:
                # Multimodal content: text part (without image fields) followed by the images.
                images = []
                for key in image_keys:
                    message_content.pop(key)
                    image_content = getattr(content, key)
                    if isinstance(image_content, list):
                        images.extend(image_content)
                    else:
                        images.append(image_content)
                history.append({"role": message.role, "content": [json.dumps(message_content), *images]})
            else:
                # Regular content: the whole payload serialized to a JSON string.
                history.append({"role": message.role, "content": json.dumps(message_content)})
        return history

    def copy(self) -> "AgentMemory":
        """
        Creates a copy of the chat memory via a dump/load round-trip.

        Returns:
            AgentMemory: A copy of the chat memory.
        """
        new_memory = AgentMemory(max_messages=self.max_messages)
        new_memory.load(self.dump())
        new_memory.current_turn_id = self.current_turn_id
        return new_memory

    def get_current_turn_id(self) -> Optional[str]:
        """
        Returns the current turn ID.

        Returns:
            Optional[str]: The current turn ID, or None if not set.
        """
        return self.current_turn_id

    def delete_turn_id(self, turn_id: str) -> None:
        """
        Delete all messages belonging to the given turn ID.

        Args:
            turn_id (str): The turn ID whose messages should be deleted.

        Raises:
            ValueError: If no message with the specified turn ID is found in the memory.
        """
        initial_length = len(self.history)
        self.history = [msg for msg in self.history if msg.turn_id != turn_id]

        if len(self.history) == initial_length:
            raise ValueError(f"Turn ID {turn_id} not found in memory.")

        # Keep current_turn_id pointing at a turn that still exists.
        if not self.history:
            self.current_turn_id = None
        elif turn_id == self.current_turn_id:
            # Fall back to the turn of the most recent remaining message.
            self.current_turn_id = self.history[-1].turn_id

    def get_message_count(self) -> int:
        """
        Returns the number of messages in the chat history.

        Returns:
            int: The number of messages.
        """
        return len(self.history)

    def dump(self) -> str:
        """
        Serializes the entire AgentMemory instance to a JSON string.

        Each message's content is stored with its fully qualified class name so that
        load() can reconstruct the original schema type.

        Returns:
            str: A JSON string representation of the AgentMemory.
        """
        serialized_history = []
        for message in self.history:
            content_class = message.content.__class__
            serialized_message = {
                "role": message.role,
                "content": {
                    "class_name": f"{content_class.__module__}.{content_class.__name__}",
                    # mode="json" so json.dumps below does not fail on Path/datetime/UUID/... fields;
                    # pydantic validation in load() converts these values back.
                    "data": message.content.model_dump(mode="json"),
                },
                "turn_id": message.turn_id,
            }
            serialized_history.append(serialized_message)

        memory_data = {
            "history": serialized_history,
            "max_messages": self.max_messages,
            "current_turn_id": self.current_turn_id,
        }
        return json.dumps(memory_data)

    def load(self, serialized_data: str) -> None:
        """
        Deserializes a JSON string and loads it into the AgentMemory instance.

        The instance is only modified once the whole payload has been parsed
        successfully; on failure the existing history is left untouched.

        Args:
            serialized_data (str): A JSON string representation of the AgentMemory.

        Raises:
            ValueError: If the serialized data is invalid or cannot be deserialized
                (pydantic validation errors, which subclass ValueError, propagate as-is).
        """
        try:
            memory_data = json.loads(serialized_data)
            max_messages = memory_data["max_messages"]
            current_turn_id = memory_data["current_turn_id"]

            history: List[Message] = []
            for message_data in memory_data["history"]:
                content_info = message_data["content"]
                content_class = self._get_class_from_string(content_info["class_name"])
                content_instance = content_class(**content_info["data"])
                message = Message(role=message_data["role"], content=content_instance, turn_id=message_data["turn_id"])
                history.append(message)
        except (json.JSONDecodeError, KeyError, AttributeError, TypeError, ImportError) as e:
            raise ValueError(f"Invalid serialized data: {e}") from e

        # Commit atomically so a failed load never leaves a half-cleared memory.
        self.history = history
        self.max_messages = max_messages
        self.current_turn_id = current_turn_id

    @staticmethod
    def _get_class_from_string(class_string: str) -> Type[BaseIOSchema]:
        """
        Retrieves a BaseIOSchema subclass from its fully qualified name.

        NOTE(security): the named module is still imported before the type check, so
        loading untrusted data can trigger arbitrary module import side effects.

        Args:
            class_string (str): The fully qualified class name.

        Returns:
            Type[BaseIOSchema]: The class object.

        Raises:
            AttributeError: If the class cannot be found in its module.
            ImportError: If the module cannot be imported.
            TypeError: If the resolved object is not a BaseIOSchema subclass.
        """
        module_name, class_name = class_string.rsplit(".", 1)
        module = __import__(module_name, fromlist=[class_name])
        resolved = getattr(module, class_name)
        # Refuse arbitrary callables (e.g. subprocess.run) smuggled in via serialized data.
        if not (isinstance(resolved, type) and issubclass(resolved, BaseIOSchema)):
            raise TypeError(f"{class_string} is not a BaseIOSchema subclass")
        return resolved
if __name__ == "__main__":
    import os
    from typing import Dict as TypeDict, List as TypeList

    import instructor

    # Demo schemas covering nested models, collections and multimodal fields.
    # They must stay at module level so AgentMemory.load() can resolve them by name.
    class NestedSchema(BaseIOSchema):
        """A nested schema for testing"""

        nested_field: str = Field(..., description="A nested field")
        nested_int: int = Field(..., description="A nested integer")

    class ComplexInputSchema(BaseIOSchema):
        """Complex Input Schema"""

        text_field: str = Field(..., description="A text field")
        number_field: float = Field(..., description="A number field")
        list_field: TypeList[str] = Field(..., description="A list of strings")
        nested_field: NestedSchema = Field(..., description="A nested schema")

    class ComplexOutputSchema(BaseIOSchema):
        """Complex Output Schema"""

        response_text: str = Field(..., description="A response text")
        calculated_value: int = Field(..., description="A calculated value")
        data_dict: TypeDict[str, NestedSchema] = Field(..., description="A dictionary of nested schemas")

    class MultimodalSchema(BaseIOSchema):
        """Schema for testing multimodal content"""

        instruction_text: str = Field(..., description="The instruction text")
        images: TypeList[instructor.Image] = Field(..., description="The images to analyze")

    source_memory = AgentMemory(max_messages=10)

    seed_messages = [
        (
            "user",
            ComplexInputSchema(
                text_field="Hello, this is a complex input",
                number_field=3.14159,
                list_field=["item1", "item2", "item3"],
                nested_field=NestedSchema(nested_field="Nested input", nested_int=42),
            ),
        ),
        (
            "assistant",
            ComplexOutputSchema(
                response_text="This is a complex response",
                calculated_value=100,
                data_dict={
                    "key1": NestedSchema(nested_field="Nested output 1", nested_int=10),
                    "key2": NestedSchema(nested_field="Nested output 2", nested_int=20),
                },
            ),
        ),
    ]
    for role, payload in seed_messages:
        source_memory.add_message(role, payload)

    # Exercise the multimodal path only when a sample image is available locally.
    sample_image = os.path.join("test_images", "test.jpg")
    if os.path.exists(sample_image):
        source_memory.add_message(
            "user",
            MultimodalSchema(
                instruction_text="Please analyze this image",
                images=[instructor.Image.from_path(sample_image)],
            ),
        )

    # Round-trip the memory through its JSON representation.
    serialized = source_memory.dump()
    print("Dumped data:", serialized, sep="\n")

    restored_memory = AgentMemory()
    restored_memory.load(serialized)

    print("\nLoaded memory details:")
    for position, entry in enumerate(restored_memory.history, start=1):
        print(
            f"\nMessage {position}:",
            f"Role: {entry.role}",
            f"Turn ID: {entry.turn_id}",
            f"Content type: {type(entry.content).__name__}",
            "Content:",
            sep="\n",
        )
        for field_name, field_value in entry.content.model_dump().items():
            print(f" {field_name}: {field_value}")

    print("\nFinal verification:")
    print(f"Max messages: {restored_memory.max_messages}")
    print(f"Current turn ID: {restored_memory.get_current_turn_id()}")
    print("Last message content:")
    print(restored_memory.history[-1].content.model_dump())