import json
import uuid
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Type
from instructor.multimodal import PDF, Image, Audio
from pydantic import BaseModel, Field
from atomic_agents.base.base_io_schema import BaseIOSchema
INSTRUCTOR_MULTIMODAL_TYPES = (Image, Audio, PDF)
[docs]
class Message(BaseModel):
"""
Represents a message in the chat history.
Attributes:
role (str): The role of the message sender (e.g., 'user', 'system', 'tool').
content (BaseIOSchema): The content of the message.
turn_id (Optional[str]): Unique identifier for the turn this message belongs to.
"""
role: str
content: BaseIOSchema
turn_id: Optional[str] = None
[docs]
class ChatHistory:
"""
Manages the chat history for an AI agent.
Attributes:
history (List[Message]): A list of messages representing the chat history.
max_messages (Optional[int]): Maximum number of messages to keep in history.
current_turn_id (Optional[str]): The ID of the current turn.
"""
[docs]
def __init__(self, max_messages: Optional[int] = None):
"""
Initializes the ChatHistory with an empty history and optional constraints.
Args:
max_messages (Optional[int]): Maximum number of messages to keep in history.
When exceeded, oldest messages are removed first.
"""
self.history: List[Message] = []
self.max_messages = max_messages
self.current_turn_id: Optional[str] = None
[docs]
def initialize_turn(self) -> None:
"""
Initializes a new turn by generating a random turn ID.
"""
self.current_turn_id = str(uuid.uuid4())
[docs]
def add_message(
self,
role: str,
content: BaseIOSchema,
) -> None:
"""
Adds a message to the chat history and manages overflow.
Args:
role (str): The role of the message sender.
content (BaseIOSchema): The content of the message.
"""
if self.current_turn_id is None:
self.initialize_turn()
message = Message(
role=role,
content=content,
turn_id=self.current_turn_id,
)
self.history.append(message)
self._manage_overflow()
def _manage_overflow(self) -> None:
"""
Manages the chat history overflow based on max_messages constraint.
"""
if self.max_messages is not None:
while len(self.history) > self.max_messages:
self.history.pop(0)
[docs]
def get_history(self) -> List[Dict]:
"""
Retrieves the chat history, handling both regular and multimodal content.
Returns:
List[Dict]: The list of messages in the chat history as dictionaries.
Each dictionary has 'role' and 'content' keys, where 'content' is a list
that may contain strings (JSON) or multimodal objects.
Note:
This method does not support nested multimodal content. If your schema
contains nested objects that themselves contain multimodal content,
only the top-level multimodal content will be properly processed.
"""
history = []
for message in self.history:
input_content = message.content
processed_content = []
for field_name, field in input_content.__class__.model_fields.items():
field_value = getattr(input_content, field_name)
if isinstance(field_value, list):
has_multimodal_in_list = False
for item in field_value:
if isinstance(item, INSTRUCTOR_MULTIMODAL_TYPES):
processed_content.append(item)
has_multimodal_in_list = True
if not has_multimodal_in_list:
processed_content.append(input_content.model_dump_json(include={field_name}))
else:
if isinstance(field_value, INSTRUCTOR_MULTIMODAL_TYPES):
processed_content.append(field_value)
else:
processed_content.append(input_content.model_dump_json(include={field_name}))
history.append({"role": message.role, "content": processed_content})
return history
[docs]
def copy(self) -> "ChatHistory":
"""
Creates a copy of the chat history.
Returns:
ChatHistory: A copy of the chat history.
"""
new_history = ChatHistory(max_messages=self.max_messages)
new_history.load(self.dump())
new_history.current_turn_id = self.current_turn_id
return new_history
[docs]
def get_current_turn_id(self) -> Optional[str]:
"""
Returns the current turn ID.
Returns:
Optional[str]: The current turn ID, or None if not set.
"""
return self.current_turn_id
[docs]
def delete_turn_id(self, turn_id: int):
"""
Delete messages from the history by its turn ID.
Args:
turn_id (int): The turn ID of the message to delete.
Returns:
str: A success message with the deleted turn ID.
Raises:
ValueError: If the specified turn ID is not found in the history.
"""
initial_length = len(self.history)
self.history = [msg for msg in self.history if msg.turn_id != turn_id]
if len(self.history) == initial_length:
raise ValueError(f"Turn ID {turn_id} not found in history.")
# Update current_turn_id if necessary
if not len(self.history):
self.current_turn_id = None
elif turn_id == self.current_turn_id:
# Always update to the last message's turn_id
self.current_turn_id = self.history[-1].turn_id
[docs]
def get_message_count(self) -> int:
"""
Returns the number of messages in the chat history.
Returns:
int: The number of messages.
"""
return len(self.history)
[docs]
def dump(self) -> str:
"""
Serializes the entire ChatHistory instance to a JSON string.
Returns:
str: A JSON string representation of the ChatHistory.
"""
serialized_history = []
for message in self.history:
content_class = message.content.__class__
serialized_message = {
"role": message.role,
"content": {
"class_name": f"{content_class.__module__}.{content_class.__name__}",
"data": message.content.model_dump_json(),
},
"turn_id": message.turn_id,
}
serialized_history.append(serialized_message)
history_data = {
"history": serialized_history,
"max_messages": self.max_messages,
"current_turn_id": self.current_turn_id,
}
return json.dumps(history_data)
[docs]
def load(self, serialized_data: str) -> None:
"""
Deserializes a JSON string and loads it into the ChatHistory instance.
Args:
serialized_data (str): A JSON string representation of the ChatHistory.
Raises:
ValueError: If the serialized data is invalid or cannot be deserialized.
"""
try:
history_data = json.loads(serialized_data)
self.history = []
self.max_messages = history_data["max_messages"]
self.current_turn_id = history_data["current_turn_id"]
for message_data in history_data["history"]:
content_info = message_data["content"]
content_class = self._get_class_from_string(content_info["class_name"])
content_instance = content_class.model_validate_json(content_info["data"])
# Process any Image objects to convert string paths back to Path objects
self._process_multimodal_paths(content_instance)
message = Message(role=message_data["role"], content=content_instance, turn_id=message_data["turn_id"])
self.history.append(message)
except (json.JSONDecodeError, KeyError, AttributeError, TypeError) as e:
raise ValueError(f"Invalid serialized data: {e}")
@staticmethod
def _get_class_from_string(class_string: str) -> Type[BaseIOSchema]:
"""
Retrieves a class object from its string representation.
Args:
class_string (str): The fully qualified class name.
Returns:
Type[BaseIOSchema]: The class object.
Raises:
AttributeError: If the class cannot be found.
"""
module_name, class_name = class_string.rsplit(".", 1)
module = __import__(module_name, fromlist=[class_name])
return getattr(module, class_name)
def _process_multimodal_paths(self, obj):
"""
Process multimodal objects to convert string paths to Path objects.
Note: this is necessary only for PDF and Image instructor types. The from_path
behavior is slightly different for Audio as it keeps the source as a string.
Args:
obj: The object to process.
"""
if isinstance(obj, (Image, PDF)) and isinstance(obj.source, str):
# Check if the string looks like a file path (not a URL or base64 data)
if not obj.source.startswith(("http://", "https://", "data:")):
obj.source = Path(obj.source)
elif isinstance(obj, list):
# Process each item in the list
for item in obj:
self._process_multimodal_paths(item)
elif isinstance(obj, dict):
# Process each value in the dictionary
for value in obj.values():
self._process_multimodal_paths(value)
elif hasattr(obj, "model_fields"):
# Process each field of the Pydantic model
for field_name in obj.model_fields:
if hasattr(obj, field_name):
self._process_multimodal_paths(getattr(obj, field_name))
elif hasattr(obj, "__dict__") and not isinstance(obj, Enum):
# Process each attribute of the object
for attr_name, attr_value in obj.__dict__.items():
if attr_name != "__pydantic_fields_set__": # Skip pydantic internal fields
self._process_multimodal_paths(attr_value)
if __name__ == "__main__":
import instructor
from typing import List as TypeList, Dict as TypeDict
import os
# Define complex test schemas
class NestedSchema(BaseIOSchema):
"""A nested schema for testing"""
nested_field: str = Field(..., description="A nested field")
nested_int: int = Field(..., description="A nested integer")
class ComplexInputSchema(BaseIOSchema):
"""Complex Input Schema"""
text_field: str = Field(..., description="A text field")
number_field: float = Field(..., description="A number field")
list_field: TypeList[str] = Field(..., description="A list of strings")
nested_field: NestedSchema = Field(..., description="A nested schema")
class ComplexOutputSchema(BaseIOSchema):
"""Complex Output Schema"""
response_text: str = Field(..., description="A response text")
calculated_value: int = Field(..., description="A calculated value")
data_dict: TypeDict[str, NestedSchema] = Field(..., description="A dictionary of nested schemas")
# Add a new multimodal schema for testing
class MultimodalSchema(BaseIOSchema):
"""Schema for testing multimodal content"""
instruction_text: str = Field(..., description="The instruction text")
images: List[instructor.Image] = Field(..., description="The images to analyze")
# Create and populate the original history with complex data
original_history = ChatHistory(max_messages=10)
# Add a complex input message
original_history.add_message(
"user",
ComplexInputSchema(
text_field="Hello, this is a complex input",
number_field=3.14159,
list_field=["item1", "item2", "item3"],
nested_field=NestedSchema(nested_field="Nested input", nested_int=42),
),
)
# Add a complex output message
original_history.add_message(
"assistant",
ComplexOutputSchema(
response_text="This is a complex response",
calculated_value=100,
data_dict={
"key1": NestedSchema(nested_field="Nested output 1", nested_int=10),
"key2": NestedSchema(nested_field="Nested output 2", nested_int=20),
},
),
)
# Test multimodal functionality if test image exists
test_image_path = os.path.join("test_images", "test.jpg")
if os.path.exists(test_image_path):
# Add a multimodal message
original_history.add_message(
"user",
MultimodalSchema(
instruction_text="Please analyze this image", images=[instructor.Image.from_path(test_image_path)]
),
)
# Continue with existing tests...
dumped_data = original_history.dump()
print("Dumped data:")
print(dumped_data)
# Create a new history and load the dumped data
loaded_history = ChatHistory()
loaded_history.load(dumped_data)
# Print detailed information about the loaded history
print("\nLoaded history details:")
for i, message in enumerate(loaded_history.history):
print(f"\nMessage {i + 1}:")
print(f"Role: {message.role}")
print(f"Turn ID: {message.turn_id}")
print(f"Content type: {type(message.content).__name__}")
print("Content:")
for field, value in message.content.model_dump().items():
print(f" {field}: {value}")
# Final verification
print("\nFinal verification:")
print(f"Max messages: {loaded_history.max_messages}")
print(f"Current turn ID: {loaded_history.get_current_turn_id()}")
print("Last message content:")
last_message = loaded_history.history[-1]
print(last_message.content.model_dump())