import json
import uuid
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Type
from instructor.multimodal import PDF, Image, Audio
from pydantic import BaseModel, Field
from atomic_agents.lib.base.base_io_schema import BaseIOSchema
# Instructor multimodal payload types. AgentMemory.get_history passes values of
# these types through unchanged instead of serializing them to JSON strings.
INSTRUCTOR_MULTIMODAL_TYPES = (Image, Audio, PDF)
class Message(BaseModel):
    """
    A single entry in an agent's chat history.

    Attributes:
        role (str): Who produced the message, e.g. 'user', 'system' or 'tool'.
        content (BaseIOSchema): The schema instance carrying the message payload.
        turn_id (Optional[str]): Identifier of the conversation turn this message
            belongs to, or None when no turn was assigned.
    """

    role: str
    content: BaseIOSchema
    turn_id: Optional[str] = Field(default=None)
class AgentMemory:
    """
    Manages the chat history for an AI agent.

    Messages are grouped into turns: every message added while a turn is active
    is tagged with that turn's ID, so all messages of a turn can later be
    removed together with ``delete_turn_id``.

    Attributes:
        history (List[Message]): The messages in the chat history, oldest first.
        max_messages (Optional[int]): Maximum number of messages to keep in history,
            or None for no limit.
        current_turn_id (Optional[str]): The ID of the current turn, or None if no
            turn is active.
    """

    def __init__(self, max_messages: Optional[int] = None):
        """
        Initializes the AgentMemory with an empty history and optional constraints.

        Args:
            max_messages (Optional[int]): Maximum number of messages to keep in history.
                When exceeded, oldest messages are removed first. None disables the limit.
        """
        self.history: List[Message] = []
        self.max_messages = max_messages
        self.current_turn_id: Optional[str] = None

    def initialize_turn(self) -> None:
        """
        Starts a new turn by generating a random (UUID4) turn ID.

        Messages added after this call are tagged with the new ID.
        """
        self.current_turn_id = str(uuid.uuid4())

    def add_message(
        self,
        role: str,
        content: BaseIOSchema,
    ) -> None:
        """
        Appends a message to the chat history and trims it to ``max_messages``.

        If no turn is active yet, a new one is started first so that every
        stored message carries a turn ID.

        Args:
            role (str): The role of the message sender.
            content (BaseIOSchema): The content of the message.
        """
        if self.current_turn_id is None:
            self.initialize_turn()
        message = Message(
            role=role,
            content=content,
            turn_id=self.current_turn_id,
        )
        self.history.append(message)
        self._manage_overflow()

    def _manage_overflow(self) -> None:
        """
        Drops the oldest messages until at most ``max_messages`` remain.

        Does nothing when ``max_messages`` is None.
        """
        if self.max_messages is not None:
            while len(self.history) > self.max_messages:
                self.history.pop(0)

    def get_history(self) -> List[Dict]:
        """
        Retrieves the chat history, splitting each message's content into JSON text and multimodal parts.

        Each message's content schema is processed field by field, in declaration order:

        * a field holding an instructor multimodal object (Image, Audio, PDF)
          contributes that object as-is;
        * a list field containing at least one multimodal object contributes only
          those objects (non-multimodal items of the same list are NOT included);
        * any other field contributes a JSON string of the form ``{"<field_name>": <value>}``.

        Returns:
            List[Dict]: One dict per message with 'role' and 'content' keys, where
            'content' is a list that may contain strings (JSON) or multimodal objects.

        Note:
            This method does not support nested multimodal content. If your schema
            contains nested objects that themselves contain multimodal content,
            only the top-level multimodal content will be properly processed.
        """
        history = []
        for message in self.history:
            content = message.content
            processed_content = []
            # Only the field names are needed; read them from the class, not the instance.
            for field_name in type(content).model_fields:
                field_value = getattr(content, field_name)
                if isinstance(field_value, list):
                    multimodal_items = [
                        item for item in field_value if isinstance(item, INSTRUCTOR_MULTIMODAL_TYPES)
                    ]
                    if multimodal_items:
                        processed_content.extend(multimodal_items)
                    else:
                        processed_content.append(content.model_dump_json(include={field_name}))
                elif isinstance(field_value, INSTRUCTOR_MULTIMODAL_TYPES):
                    processed_content.append(field_value)
                else:
                    processed_content.append(content.model_dump_json(include={field_name}))
            history.append({"role": message.role, "content": processed_content})
        return history

    def copy(self) -> "AgentMemory":
        """
        Creates an independent deep copy of the chat memory.

        Returns:
            AgentMemory: A new instance with the same ``max_messages`` and
            ``current_turn_id`` and deep copies of all messages.
        """
        import copy as copy_module  # aliased to avoid confusion with this method's name

        new_memory = AgentMemory(max_messages=self.max_messages)
        # Deep-copy the messages directly rather than round-tripping through
        # dump()/load(): that round-trip requires every content class to be
        # re-importable as ``module.ClassName`` (it fails for classes defined
        # inside functions or nested in other classes) and JSON-serializable.
        new_memory.history = copy_module.deepcopy(self.history)
        new_memory.current_turn_id = self.current_turn_id
        return new_memory

    def get_current_turn_id(self) -> Optional[str]:
        """
        Returns the current turn ID.

        Returns:
            Optional[str]: The current turn ID, or None if not set.
        """
        return self.current_turn_id

    def delete_turn_id(self, turn_id: str) -> None:
        """
        Deletes every message belonging to the given turn.

        If the deleted turn was the current one, the current turn becomes the turn
        of the last remaining message; if no messages remain, it is reset to None.

        Args:
            turn_id (str): The turn ID whose messages should be removed (turn IDs
                are the UUID strings generated by ``initialize_turn``).

        Raises:
            ValueError: If the specified turn ID is not found in the memory.
        """
        remaining = [msg for msg in self.history if msg.turn_id != turn_id]
        if len(remaining) == len(self.history):
            raise ValueError(f"Turn ID {turn_id} not found in memory.")
        self.history = remaining

        # Keep current_turn_id pointing at a turn that still exists.
        if not self.history:
            self.current_turn_id = None
        elif turn_id == self.current_turn_id:
            self.current_turn_id = self.history[-1].turn_id

    def get_message_count(self) -> int:
        """
        Returns the number of messages in the chat history.

        Returns:
            int: The number of messages.
        """
        return len(self.history)

    def dump(self) -> str:
        """
        Serializes the entire AgentMemory instance to a JSON string.

        The result is a JSON object with keys "history", "max_messages" and
        "current_turn_id". Each history entry stores the role, the turn ID and
        the content as {"class_name": "<module>.<ClassName>", "data": <content JSON string>}.

        Returns:
            str: A JSON string representation of the AgentMemory.
        """
        serialized_history = []
        for message in self.history:
            content_class = message.content.__class__
            serialized_message = {
                "role": message.role,
                "content": {
                    "class_name": f"{content_class.__module__}.{content_class.__name__}",
                    "data": message.content.model_dump_json(),
                },
                "turn_id": message.turn_id,
            }
            serialized_history.append(serialized_message)

        memory_data = {
            "history": serialized_history,
            "max_messages": self.max_messages,
            "current_turn_id": self.current_turn_id,
        }
        return json.dumps(memory_data)

    def load(self, serialized_data: str) -> None:
        """
        Deserializes a JSON string (as produced by ``dump``) and loads it into this instance.

        Everything is parsed and validated first; the instance's state is replaced
        only on success, so a failed load leaves the memory unchanged.

        Security: the serialized data names the modules to import for each content
        class, so only load data from trusted sources.

        Args:
            serialized_data (str): A JSON string representation of the AgentMemory.

        Raises:
            ValueError: If the serialized data is invalid or cannot be deserialized.
        """
        try:
            memory_data = json.loads(serialized_data)
            max_messages = memory_data["max_messages"]
            current_turn_id = memory_data["current_turn_id"]

            history = []
            for message_data in memory_data["history"]:
                content_info = message_data["content"]
                content_class = self._get_class_from_string(content_info["class_name"])
                content_instance = content_class.model_validate_json(content_info["data"])

                # Process any Image/PDF objects to convert string paths back to Path objects
                self._process_multimodal_paths(content_instance)

                history.append(
                    Message(role=message_data["role"], content=content_instance, turn_id=message_data["turn_id"])
                )
        except (json.JSONDecodeError, KeyError, AttributeError, TypeError, ImportError) as e:
            raise ValueError(f"Invalid serialized data: {e}") from e

        # Commit only after everything above succeeded.
        self.history = history
        self.max_messages = max_messages
        self.current_turn_id = current_turn_id

    @staticmethod
    def _get_class_from_string(class_string: str) -> Type[BaseIOSchema]:
        """
        Retrieves a BaseIOSchema subclass from its fully qualified ``module.ClassName`` string.

        Security note: this imports the named module, which executes its top-level
        code. ``class_string`` comes from serialized data, so that data must be trusted.

        Args:
            class_string (str): The fully qualified class name.

        Returns:
            Type[BaseIOSchema]: The class object.

        Raises:
            ImportError: If the module cannot be imported.
            AttributeError: If the class cannot be found in the module.
            TypeError: If the named object is not a BaseIOSchema subclass.
        """
        module_name, class_name = class_string.rsplit(".", 1)
        module = __import__(module_name, fromlist=[class_name])
        cls = getattr(module, class_name)
        # Refuse arbitrary module attributes: only schema classes may be instantiated.
        if not (isinstance(cls, type) and issubclass(cls, BaseIOSchema)):
            raise TypeError(f"{class_string} is not a BaseIOSchema subclass")
        return cls

    def _process_multimodal_paths(self, obj):
        """
        Recursively converts local-file string sources of Image/PDF objects to Path objects.

        Used after deserialization, where such sources come back as plain strings.
        Strings starting with "http://", "https://" or "data:" are left unchanged.

        Note: this is necessary only for PDF and Image instructor types. The from_path
        behavior is slightly different for Audio as it keeps the source as a string.

        Args:
            obj: The object to process. Lists, tuples, dicts, pydantic models and
                plain objects with a ``__dict__`` are traversed recursively; the
                object graph is modified in place.
        """
        # NOTE(review): there is no cycle detection; a self-referencing plain object
        # would recurse until RecursionError.
        if isinstance(obj, (Image, PDF)) and isinstance(obj.source, str):
            # Only plain file paths are converted; URLs and data URIs stay strings.
            if not obj.source.startswith(("http://", "https://", "data:")):
                obj.source = Path(obj.source)
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                self._process_multimodal_paths(item)
        elif isinstance(obj, dict):
            for value in obj.values():
                self._process_multimodal_paths(value)
        elif isinstance(obj, BaseModel):
            # Read model_fields from the class: instance access is deprecated in Pydantic 2.11+.
            for field_name in type(obj).model_fields:
                if hasattr(obj, field_name):
                    self._process_multimodal_paths(getattr(obj, field_name))
        elif hasattr(obj, "__dict__") and not isinstance(obj, (Enum, type)):
            # Plain objects: walk their attributes. Class objects are skipped since
            # they hold no message data.
            for attr_name, attr_value in obj.__dict__.items():
                if attr_name != "__pydantic_fields_set__":  # Skip pydantic internal fields
                    self._process_multimodal_paths(attr_value)
if __name__ == "__main__":
    import instructor

    # Test schemas. They are defined at module level (inside this `if`, not a
    # function) so that load() can re-import them as `__main__.<ClassName>`.
    class NestedSchema(BaseIOSchema):
        """A nested schema for testing"""

        nested_field: str = Field(..., description="A nested field")
        nested_int: int = Field(..., description="A nested integer")

    class ComplexInputSchema(BaseIOSchema):
        """Complex Input Schema"""

        text_field: str = Field(..., description="A text field")
        number_field: float = Field(..., description="A number field")
        list_field: List[str] = Field(..., description="A list of strings")
        nested_field: NestedSchema = Field(..., description="A nested schema")

    class ComplexOutputSchema(BaseIOSchema):
        """Complex Output Schema"""

        response_text: str = Field(..., description="A response text")
        calculated_value: int = Field(..., description="A calculated value")
        data_dict: Dict[str, NestedSchema] = Field(..., description="A dictionary of nested schemas")

    class MultimodalSchema(BaseIOSchema):
        """Schema for testing multimodal content"""

        instruction_text: str = Field(..., description="The instruction text")
        images: List[instructor.Image] = Field(..., description="The images to analyze")

    # Create and populate the original memory with complex data
    original_memory = AgentMemory(max_messages=10)

    # Add a complex input message
    original_memory.add_message(
        "user",
        ComplexInputSchema(
            text_field="Hello, this is a complex input",
            number_field=3.14159,
            list_field=["item1", "item2", "item3"],
            nested_field=NestedSchema(nested_field="Nested input", nested_int=42),
        ),
    )

    # Add a complex output message
    original_memory.add_message(
        "assistant",
        ComplexOutputSchema(
            response_text="This is a complex response",
            calculated_value=100,
            data_dict={
                "key1": NestedSchema(nested_field="Nested output 1", nested_int=10),
                "key2": NestedSchema(nested_field="Nested output 2", nested_int=20),
            },
        ),
    )

    # Exercise multimodal content only when the optional test image is present.
    test_image_path = Path("test_images") / "test.jpg"
    if test_image_path.exists():
        original_memory.add_message(
            "user",
            MultimodalSchema(
                instruction_text="Please analyze this image", images=[instructor.Image.from_path(test_image_path)]
            ),
        )

    # Serialize the memory.
    dumped_data = original_memory.dump()
    print("Dumped data:")
    print(dumped_data)

    # Create a new memory and load the dumped data
    loaded_memory = AgentMemory()
    loaded_memory.load(dumped_data)

    # Print detailed information about the loaded memory
    print("\nLoaded memory details:")
    for i, message in enumerate(loaded_memory.history):
        print(f"\nMessage {i + 1}:")
        print(f"Role: {message.role}")
        print(f"Turn ID: {message.turn_id}")
        print(f"Content type: {type(message.content).__name__}")
        print("Content:")
        for field, value in message.content.model_dump().items():
            print(f"  {field}: {value}")

    # Final verification
    print("\nFinal verification:")
    print(f"Max messages: {loaded_memory.max_messages}")
    print(f"Current turn ID: {loaded_memory.get_current_turn_id()}")
    print("Last message content:")
    last_message = loaded_memory.history[-1]
    print(last_message.content.model_dump())

    # The dump/load round-trip must preserve the memory's structure exactly.
    assert loaded_memory.max_messages == original_memory.max_messages
    assert loaded_memory.get_current_turn_id() == original_memory.get_current_turn_id()
    assert loaded_memory.get_message_count() == original_memory.get_message_count()
    for original_message, loaded_message in zip(original_memory.history, loaded_memory.history):
        assert loaded_message.role == original_message.role
        assert loaded_message.turn_id == original_message.turn_id
        assert type(loaded_message.content) is type(original_message.content)
    print("Round-trip verification passed.")