import json
import uuid
from typing import Dict, List, Optional, Type
from pathlib import Path
from instructor.multimodal import PDF, Image, Audio
from pydantic import BaseModel, Field
from atomic_agents.lib.base.base_io_schema import BaseIOSchema
INSTRUCTOR_MULTIMODAL_TYPES = (Image, Audio, PDF)
class Message(BaseModel):
"""
Represents a message in the chat history.
Attributes:
role (str): The role of the message sender (e.g., 'user', 'system', 'tool').
content (BaseIOSchema): The content of the message.
turn_id (Optional[str]): Unique identifier for the turn this message belongs to.
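
    Example:
        A minimal sketch; ``ChatSchema`` here is a hypothetical
        ``BaseIOSchema`` subclass used purely for illustration::

            msg = Message(role="user", content=ChatSchema(text="Hello"), turn_id="turn-1")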
"""
role: str
content: BaseIOSchema
turn_id: Optional[str] = None
class AgentMemory:
"""
Manages the chat history for an AI agent.
Attributes:
history (List[Message]): A list of messages representing the chat history.
max_messages (Optional[int]): Maximum number of messages to keep in history.
current_turn_id (Optional[str]): The ID of the current turn.
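
    Example:
        A minimal usage sketch; ``ChatSchema`` is a hypothetical
        ``BaseIOSchema`` subclass used purely for illustration::

            memory = AgentMemory(max_messages=10)
            memory.add_message("user", ChatSchema(text="Hello"))
            memory.add_message("assistant", ChatSchema(text="Hi there"))
            print(memory.get_message_count())  # 2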
"""
def __init__(self, max_messages: Optional[int] = None):
"""
Initializes the AgentMemory with an empty history and optional constraints.
Args:
max_messages (Optional[int]): Maximum number of messages to keep in history.
When exceeded, oldest messages are removed first.
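
        Example:
            An illustrative sketch of the overflow behavior, assuming a
            hypothetical ``ChatSchema`` subclass of ``BaseIOSchema``::

                memory = AgentMemory(max_messages=2)
                for text in ("first", "second", "third"):
                    memory.add_message("user", ChatSchema(text=text))
                # Only "second" and "third" remain; "first" was dropped.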
"""
self.history: List[Message] = []
self.max_messages = max_messages
self.current_turn_id: Optional[str] = None
def initialize_turn(self) -> None:
"""
Initializes a new turn by generating a random turn ID.
"""
self.current_turn_id = str(uuid.uuid4())
def add_message(
self,
role: str,
content: BaseIOSchema,
) -> None:
"""
Adds a message to the chat history and manages overflow.
Args:
role (str): The role of the message sender.
content (BaseIOSchema): The content of the message.
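
        Example:
            An illustrative sketch with a hypothetical ``ChatSchema`` subclass
            of ``BaseIOSchema``; messages added before the next call to
            ``initialize_turn()`` share the same turn ID::

                memory = AgentMemory()
                memory.add_message("user", ChatSchema(text="What is 2 + 2?"))
                memory.add_message("assistant", ChatSchema(text="4"))
                assert memory.history[0].turn_id == memory.history[1].turn_id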
"""
if self.current_turn_id is None:
self.initialize_turn()
message = Message(
role=role,
content=content,
turn_id=self.current_turn_id,
)
self.history.append(message)
self._manage_overflow()
def _manage_overflow(self) -> None:
"""
Manages the chat history overflow based on max_messages constraint.
"""
if self.max_messages is not None:
while len(self.history) > self.max_messages:
self.history.pop(0)
def get_history(self) -> List[Dict]:
"""
Retrieves the chat history, handling both regular and multimodal content.
Returns:
List[Dict]: The list of messages in the chat history as dictionaries.
Each dictionary has 'role' and 'content' keys, where 'content' is a list
that may contain strings (JSON) or multimodal objects.
Note:
This method does not support nested multimodal content. If your schema
contains nested objects that themselves contain multimodal content,
only the top-level multimodal content will be properly processed.
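
        Example:
            An illustrative sketch of the returned structure, assuming a
            hypothetical ``ChatSchema`` with a single ``text`` field::

                memory = AgentMemory()
                memory.add_message("user", ChatSchema(text="Hello"))
                memory.get_history()
                # [{'role': 'user', 'content': ['{"text": "Hello"}']}]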
"""
history = []
for message in self.history:
input_content = message.content
processed_content = []
for field_name, field in input_content.__class__.model_fields.items():
field_value = getattr(input_content, field_name)
if isinstance(field_value, list):
for item in field_value:
if isinstance(item, INSTRUCTOR_MULTIMODAL_TYPES):
processed_content.append(item)
else:
processed_content.append(json.dumps({field_name: field_value}))
else:
if isinstance(field_value, INSTRUCTOR_MULTIMODAL_TYPES):
processed_content.append(field_value)
else:
processed_content.append(json.dumps({field_name: field_value}))
history.append({"role": message.role, "content": processed_content})
return history
def copy(self) -> "AgentMemory":
"""
Creates a copy of the chat memory.
Returns:
AgentMemory: A copy of the chat memory.
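
        Example:
            Illustrative only (``memory`` stands for an existing, populated
            instance); the copy is rebuilt via ``dump()`` and ``load()``, so
            mutating one memory does not affect the other::

                snapshot = memory.copy()
                assert snapshot.get_message_count() == memory.get_message_count()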
"""
new_memory = AgentMemory(max_messages=self.max_messages)
new_memory.load(self.dump())
new_memory.current_turn_id = self.current_turn_id
return new_memory
def get_current_turn_id(self) -> Optional[str]:
"""
Returns the current turn ID.
Returns:
Optional[str]: The current turn ID, or None if not set.
"""
return self.current_turn_id
    def delete_turn_id(self, turn_id: str) -> None:
        """
        Deletes all messages belonging to the specified turn ID from the memory.

        Args:
            turn_id (str): The turn ID of the messages to delete.
Raises:
ValueError: If the specified turn ID is not found in the memory.
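
        Example:
            An illustrative sketch with a hypothetical ``ChatSchema`` schema::

                memory = AgentMemory()
                memory.add_message("user", ChatSchema(text="Hello"))
                turn_id = memory.get_current_turn_id()
                memory.delete_turn_id(turn_id)
                print(memory.get_message_count())  # 0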
"""
initial_length = len(self.history)
self.history = [msg for msg in self.history if msg.turn_id != turn_id]
if len(self.history) == initial_length:
raise ValueError(f"Turn ID {turn_id} not found in memory.")
# Update current_turn_id if necessary
        if not self.history:
self.current_turn_id = None
elif turn_id == self.current_turn_id:
# Always update to the last message's turn_id
self.current_turn_id = self.history[-1].turn_id
def get_message_count(self) -> int:
"""
Returns the number of messages in the chat history.
Returns:
int: The number of messages.
"""
return len(self.history)
def dump(self) -> str:
"""
Serializes the entire AgentMemory instance to a JSON string.
Returns:
str: A JSON string representation of the AgentMemory.
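
        Example:
            Illustrative only (``memory`` stands for an existing instance);
            the returned string can later be passed to ``load()``::

                serialized = memory.dump()
                restored = AgentMemory()
                restored.load(serialized)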
"""
serialized_history = []
for message in self.history:
content_class = message.content.__class__
serialized_message = {
"role": message.role,
"content": {
"class_name": f"{content_class.__module__}.{content_class.__name__}",
"data": message.content.model_dump_json(),
},
"turn_id": message.turn_id,
}
serialized_history.append(serialized_message)
memory_data = {
"history": serialized_history,
"max_messages": self.max_messages,
"current_turn_id": self.current_turn_id,
}
return json.dumps(memory_data)
def load(self, serialized_data: str) -> None:
"""
Deserializes a JSON string and loads it into the AgentMemory instance.
Args:
serialized_data (str): A JSON string representation of the AgentMemory.
Raises:
ValueError: If the serialized data is invalid or cannot be deserialized.
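
        Example:
            An illustrative round trip (``original`` stands for an already
            populated ``AgentMemory``), including the error raised on
            malformed input::

                restored = AgentMemory()
                restored.load(original.dump())
                try:
                    restored.load("not valid json")
                except ValueError as exc:
                    print(exc)  # Invalid serialized data: ...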
"""
try:
memory_data = json.loads(serialized_data)
self.history = []
self.max_messages = memory_data["max_messages"]
self.current_turn_id = memory_data["current_turn_id"]
for message_data in memory_data["history"]:
content_info = message_data["content"]
content_class = self._get_class_from_string(content_info["class_name"])
content_instance = content_class.model_validate_json(content_info["data"])
                # Convert any string paths inside multimodal objects (Image/PDF) back to Path objects
self._process_multimodal_paths(content_instance)
message = Message(role=message_data["role"], content=content_instance, turn_id=message_data["turn_id"])
self.history.append(message)
except (json.JSONDecodeError, KeyError, AttributeError, TypeError) as e:
raise ValueError(f"Invalid serialized data: {e}")
@staticmethod
def _get_class_from_string(class_string: str) -> Type[BaseIOSchema]:
"""
Retrieves a class object from its string representation.
Args:
class_string (str): The fully qualified class name.
Returns:
Type[BaseIOSchema]: The class object.
Raises:
AttributeError: If the class cannot be found.
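
        Example:
            Illustrative only; any importable class can be resolved this way,
            shown here with a standard-library class::

                AgentMemory._get_class_from_string("pathlib.Path")
                # <class 'pathlib.Path'>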
"""
module_name, class_name = class_string.rsplit(".", 1)
module = __import__(module_name, fromlist=[class_name])
return getattr(module, class_name)
def _process_multimodal_paths(self, obj):
"""
Process multimodal objects to convert string paths to Path objects.
        Note: this is only necessary for the PDF and Image instructor types;
        ``from_path`` behaves slightly differently for Audio, which keeps the
        source as a string.
Args:
obj: The object to process.
"""
if isinstance(obj, (Image, PDF)) and isinstance(obj.source, str):
# Check if the string looks like a file path (not a URL or base64 data)
if not obj.source.startswith(("http://", "https://", "data:")):
obj.source = Path(obj.source)
elif isinstance(obj, list):
# Process each item in the list
for item in obj:
self._process_multimodal_paths(item)
elif isinstance(obj, dict):
# Process each value in the dictionary
for value in obj.values():
self._process_multimodal_paths(value)
elif hasattr(obj, "model_fields"):
# Process each field of the Pydantic model
for field_name in obj.model_fields:
if hasattr(obj, field_name):
self._process_multimodal_paths(getattr(obj, field_name))
elif hasattr(obj, "__dict__"):
# Process each attribute of the object
for attr_name, attr_value in obj.__dict__.items():
if attr_name != "__pydantic_fields_set__": # Skip pydantic internal fields
self._process_multimodal_paths(attr_value)
if __name__ == "__main__":
import instructor
from typing import List as TypeList, Dict as TypeDict
import os
# Define complex test schemas
class NestedSchema(BaseIOSchema):
"""A nested schema for testing"""
nested_field: str = Field(..., description="A nested field")
nested_int: int = Field(..., description="A nested integer")
class ComplexInputSchema(BaseIOSchema):
"""Complex Input Schema"""
text_field: str = Field(..., description="A text field")
number_field: float = Field(..., description="A number field")
list_field: TypeList[str] = Field(..., description="A list of strings")
nested_field: NestedSchema = Field(..., description="A nested schema")
class ComplexOutputSchema(BaseIOSchema):
"""Complex Output Schema"""
response_text: str = Field(..., description="A response text")
calculated_value: int = Field(..., description="A calculated value")
data_dict: TypeDict[str, NestedSchema] = Field(..., description="A dictionary of nested schemas")
# Add a new multimodal schema for testing
class MultimodalSchema(BaseIOSchema):
"""Schema for testing multimodal content"""
instruction_text: str = Field(..., description="The instruction text")
images: List[instructor.Image] = Field(..., description="The images to analyze")
# Create and populate the original memory with complex data
original_memory = AgentMemory(max_messages=10)
# Add a complex input message
original_memory.add_message(
"user",
ComplexInputSchema(
text_field="Hello, this is a complex input",
number_field=3.14159,
list_field=["item1", "item2", "item3"],
nested_field=NestedSchema(nested_field="Nested input", nested_int=42),
),
)
# Add a complex output message
original_memory.add_message(
"assistant",
ComplexOutputSchema(
response_text="This is a complex response",
calculated_value=100,
data_dict={
"key1": NestedSchema(nested_field="Nested output 1", nested_int=10),
"key2": NestedSchema(nested_field="Nested output 2", nested_int=20),
},
),
)
# Test multimodal functionality if test image exists
test_image_path = os.path.join("test_images", "test.jpg")
if os.path.exists(test_image_path):
# Add a multimodal message
original_memory.add_message(
"user",
MultimodalSchema(
instruction_text="Please analyze this image", images=[instructor.Image.from_path(test_image_path)]
),
)
# Continue with existing tests...
dumped_data = original_memory.dump()
print("Dumped data:")
print(dumped_data)
# Create a new memory and load the dumped data
loaded_memory = AgentMemory()
loaded_memory.load(dumped_data)
# Print detailed information about the loaded memory
print("\nLoaded memory details:")
for i, message in enumerate(loaded_memory.history):
print(f"\nMessage {i + 1}:")
print(f"Role: {message.role}")
print(f"Turn ID: {message.turn_id}")
print(f"Content type: {type(message.content).__name__}")
print("Content:")
for field, value in message.content.model_dump().items():
print(f" {field}: {value}")
# Final verification
print("\nFinal verification:")
print(f"Max messages: {loaded_memory.max_messages}")
print(f"Current turn ID: {loaded_memory.get_current_turn_id()}")
print("Last message content:")
last_message = loaded_memory.history[-1]
print(last_message.content.model_dump())