# Source code for atomic_agents.context.chat_history

import importlib
import json
import uuid
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Type

from instructor.multimodal import PDF, Image, Audio
from pydantic import BaseModel, Field

from atomic_agents.base.base_io_schema import BaseIOSchema


# Instructor multimodal content classes. ChatHistory.get_history passes field values of these
# types through unchanged (instead of JSON-serializing them) when building the message list.
INSTRUCTOR_MULTIMODAL_TYPES = (Image, Audio, PDF)


class Message(BaseModel):
    """
    A single entry stored in a ChatHistory.

    Attributes:
        role (str): Who produced the message (for example 'user', 'system' or 'tool').
        content (BaseIOSchema): The schema instance holding the message payload.
        turn_id (Optional[str]): Identifier of the turn this message was recorded in,
            or None when no turn is associated.
    """

    # Sender of the message.
    role: str
    # Structured payload of the message.
    content: BaseIOSchema
    # Turn grouping key; ChatHistory.add_message fills it from the current turn.
    turn_id: Optional[str] = None
class ChatHistory:
    """
    Manages the chat history for an AI agent.

    Attributes:
        history (List[Message]): A list of messages representing the chat history.
        max_messages (Optional[int]): Maximum number of messages to keep in history.
        current_turn_id (Optional[str]): The ID of the current turn.
    """

    def __init__(self, max_messages: Optional[int] = None):
        """
        Initializes the ChatHistory with an empty history and optional constraints.

        Args:
            max_messages (Optional[int]): Maximum number of messages to keep in history.
                When exceeded, oldest messages are removed first. ``None`` disables the limit.
        """
        self.history: List[Message] = []
        self.max_messages = max_messages
        self.current_turn_id: Optional[str] = None

    def initialize_turn(self) -> None:
        """
        Initializes a new turn by generating a random turn ID (a UUID4 string).
        """
        self.current_turn_id = str(uuid.uuid4())

    def add_message(
        self,
        role: str,
        content: BaseIOSchema,
    ) -> None:
        """
        Adds a message to the chat history and manages overflow.

        If no turn is active yet, a new turn ID is generated first so every stored
        message carries a turn ID.

        Args:
            role (str): The role of the message sender.
            content (BaseIOSchema): The content of the message.
        """
        if self.current_turn_id is None:
            self.initialize_turn()

        message = Message(
            role=role,
            content=content,
            turn_id=self.current_turn_id,
        )
        self.history.append(message)
        self._manage_overflow()

    def _manage_overflow(self) -> None:
        """
        Trims the oldest messages so the history respects the ``max_messages`` constraint.

        A negative ``max_messages`` is treated like 0 (the history is emptied) instead of
        failing with ``IndexError`` when popping from an empty list.
        """
        if self.max_messages is None:
            return
        # Remove the whole surplus with one slice deletion; repeated pop(0) shifts the
        # entire list on every call.
        excess = len(self.history) - max(self.max_messages, 0)
        if excess > 0:
            del self.history[:excess]

    def get_history(self) -> List[Dict]:
        """
        Retrieves the chat history, handling both regular and multimodal content.

        Returns:
            List[Dict]: The list of messages in the chat history as dictionaries.
            Each dictionary has 'role' and 'content' keys, where 'content' is a list
            that may contain strings (JSON) or multimodal objects.

        Note:
            This method does not support nested multimodal content. If your schema contains
            nested objects that themselves contain multimodal content, only the top-level
            multimodal content will be properly processed.
            For a list field that contains at least one multimodal object, only the
            multimodal items of that list are emitted; its other items are not included.
        """
        history = []
        for message in self.history:
            input_content = message.content
            processed_content = []
            for field_name in input_content.__class__.model_fields:
                field_value = getattr(input_content, field_name)
                if isinstance(field_value, list):
                    multimodal_items = [
                        item for item in field_value if isinstance(item, INSTRUCTOR_MULTIMODAL_TYPES)
                    ]
                    if multimodal_items:
                        processed_content.extend(multimodal_items)
                    else:
                        # Plain list: serialize just this field as a JSON object.
                        processed_content.append(input_content.model_dump_json(include={field_name}))
                elif isinstance(field_value, INSTRUCTOR_MULTIMODAL_TYPES):
                    processed_content.append(field_value)
                else:
                    processed_content.append(input_content.model_dump_json(include={field_name}))

            history.append({"role": message.role, "content": processed_content})

        return history

    def copy(self) -> "ChatHistory":
        """
        Creates a copy of the chat history by round-tripping it through dump()/load().

        Returns:
            ChatHistory: A copy of the chat history.
        """
        new_history = ChatHistory(max_messages=self.max_messages)
        new_history.load(self.dump())
        new_history.current_turn_id = self.current_turn_id
        return new_history

    def get_current_turn_id(self) -> Optional[str]:
        """
        Returns the current turn ID.

        Returns:
            Optional[str]: The current turn ID, or None if not set.
        """
        return self.current_turn_id

    def delete_turn_id(self, turn_id: str) -> None:
        """
        Deletes every message in the history that belongs to the given turn.

        If the history becomes empty, the current turn ID is reset to None. If the deleted
        turn was the current one, the current turn ID moves to the last remaining message's turn.

        Args:
            turn_id (str): The turn ID whose messages should be removed.

        Raises:
            ValueError: If no message in the history has the specified turn ID.
        """
        remaining = [msg for msg in self.history if msg.turn_id != turn_id]
        if len(remaining) == len(self.history):
            raise ValueError(f"Turn ID {turn_id} not found in history.")
        self.history = remaining

        # Update current_turn_id if necessary
        if not self.history:
            self.current_turn_id = None
        elif turn_id == self.current_turn_id:
            # Always update to the last message's turn_id
            self.current_turn_id = self.history[-1].turn_id

    def get_message_count(self) -> int:
        """
        Returns the number of messages in the chat history.

        Returns:
            int: The number of messages.
        """
        return len(self.history)

    def dump(self) -> str:
        """
        Serializes the entire ChatHistory instance to a JSON string.

        Each message's content is stored as its fully qualified class name plus the
        content's own JSON dump, so load() can re-create the original schema type.

        Returns:
            str: A JSON string representation of the ChatHistory.
        """
        serialized_history = []
        for message in self.history:
            content_class = message.content.__class__
            serialized_message = {
                "role": message.role,
                "content": {
                    "class_name": f"{content_class.__module__}.{content_class.__name__}",
                    "data": message.content.model_dump_json(),
                },
                "turn_id": message.turn_id,
            }
            serialized_history.append(serialized_message)

        history_data = {
            "history": serialized_history,
            "max_messages": self.max_messages,
            "current_turn_id": self.current_turn_id,
        }
        return json.dumps(history_data)

    def load(self, serialized_data: str) -> None:
        """
        Deserializes a JSON string and loads it into the ChatHistory instance.

        The instance is only modified after the whole payload has been parsed
        successfully; on error it keeps its previous state.

        Args:
            serialized_data (str): A JSON string representation of the ChatHistory.

        Raises:
            ValueError: If the serialized data is invalid or cannot be deserialized
                (including unknown modules/classes and classes that are not BaseIOSchema
                subclasses). Pydantic validation errors of the content propagate as-is
                (they are ValueError subclasses).
        """
        try:
            history_data = json.loads(serialized_data)
            max_messages = history_data["max_messages"]
            current_turn_id = history_data["current_turn_id"]

            messages: List[Message] = []
            for message_data in history_data["history"]:
                content_info = message_data["content"]
                content_class = self._get_class_from_string(content_info["class_name"])
                # Only schema classes may be instantiated from serialized data; anything else
                # would be rejected by Message anyway, but must not be invoked first.
                if not (isinstance(content_class, type) and issubclass(content_class, BaseIOSchema)):
                    raise TypeError(f"{content_info['class_name']} is not a BaseIOSchema subclass")
                content_instance = content_class.model_validate_json(content_info["data"])

                # Process any Image objects to convert string paths back to Path objects
                self._process_multimodal_paths(content_instance)

                messages.append(
                    Message(role=message_data["role"], content=content_instance, turn_id=message_data["turn_id"])
                )
        except (json.JSONDecodeError, KeyError, AttributeError, TypeError, ImportError) as e:
            raise ValueError(f"Invalid serialized data: {e}") from e

        # Commit only after everything parsed successfully.
        self.history = messages
        self.max_messages = max_messages
        self.current_turn_id = current_turn_id

    @staticmethod
    def _get_class_from_string(class_string: str) -> Type[BaseIOSchema]:
        """
        Retrieves a class object from its fully qualified string representation.

        NOTE(security): this imports whatever module the string names, so only load
        histories that come from a trusted source.

        Args:
            class_string (str): The fully qualified class name ("package.module.ClassName").

        Returns:
            Type[BaseIOSchema]: The class object.

        Raises:
            ValueError: If ``class_string`` has no module component.
            ImportError: If the module cannot be imported.
            AttributeError: If the class cannot be found in the module.
        """
        module_name, sep, class_name = class_string.rpartition(".")
        if not sep or not module_name:
            raise ValueError(f"Invalid class path: {class_string!r}")
        module = importlib.import_module(module_name)
        return getattr(module, class_name)

    def _process_multimodal_paths(self, obj):
        """
        Process multimodal objects to convert string paths to Path objects.

        Note: this is necessary only for PDF and Image instructor types. The from_path
        behavior is slightly different for Audio as it keeps the source as a string.

        Args:
            obj: The object to process (walked recursively through lists, dicts,
                Pydantic models and plain objects; modified in place).
        """
        if isinstance(obj, (Image, PDF)) and isinstance(obj.source, str):
            # Check if the string looks like a file path (not a URL or base64 data)
            if not obj.source.startswith(("http://", "https://", "data:")):
                obj.source = Path(obj.source)
        elif isinstance(obj, list):
            for item in obj:
                self._process_multimodal_paths(item)
        elif isinstance(obj, dict):
            for value in obj.values():
                self._process_multimodal_paths(value)
        elif hasattr(type(obj), "model_fields"):
            # Read model_fields from the class: instance access is deprecated in Pydantic 2.11+.
            for field_name in type(obj).model_fields:
                if hasattr(obj, field_name):
                    self._process_multimodal_paths(getattr(obj, field_name))
        elif hasattr(obj, "__dict__") and not isinstance(obj, Enum):
            for attr_name, attr_value in obj.__dict__.items():
                if attr_name != "__pydantic_fields_set__":  # Skip pydantic internal fields
                    self._process_multimodal_paths(attr_value)
if __name__ == "__main__":
    # Demo / smoke test: builds a ChatHistory with nested (and, if a test image is present,
    # multimodal) schemas, round-trips it through dump()/load(), and prints the result.
    import instructor

    # Define complex test schemas
    class NestedSchema(BaseIOSchema):
        """A nested schema for testing"""

        nested_field: str = Field(..., description="A nested field")
        nested_int: int = Field(..., description="A nested integer")

    class ComplexInputSchema(BaseIOSchema):
        """Complex Input Schema"""

        text_field: str = Field(..., description="A text field")
        number_field: float = Field(..., description="A number field")
        list_field: List[str] = Field(..., description="A list of strings")
        nested_field: NestedSchema = Field(..., description="A nested schema")

    class ComplexOutputSchema(BaseIOSchema):
        """Complex Output Schema"""

        response_text: str = Field(..., description="A response text")
        calculated_value: int = Field(..., description="A calculated value")
        data_dict: Dict[str, NestedSchema] = Field(..., description="A dictionary of nested schemas")

    class MultimodalSchema(BaseIOSchema):
        """Schema for testing multimodal content"""

        instruction_text: str = Field(..., description="The instruction text")
        images: List[instructor.Image] = Field(..., description="The images to analyze")

    # Create and populate the original history with complex data
    original_history = ChatHistory(max_messages=10)

    # Add a complex input message
    original_history.add_message(
        "user",
        ComplexInputSchema(
            text_field="Hello, this is a complex input",
            number_field=3.14159,
            list_field=["item1", "item2", "item3"],
            nested_field=NestedSchema(nested_field="Nested input", nested_int=42),
        ),
    )

    # Add a complex output message
    original_history.add_message(
        "assistant",
        ComplexOutputSchema(
            response_text="This is a complex response",
            calculated_value=100,
            data_dict={
                "key1": NestedSchema(nested_field="Nested output 1", nested_int=10),
                "key2": NestedSchema(nested_field="Nested output 2", nested_int=20),
            },
        ),
    )

    # Exercise multimodal serialization only when the optional test image exists.
    test_image_path = Path("test_images") / "test.jpg"
    if test_image_path.exists():
        original_history.add_message(
            "user",
            MultimodalSchema(
                instruction_text="Please analyze this image",
                images=[instructor.Image.from_path(str(test_image_path))],
            ),
        )

    # Round-trip the history through its JSON representation.
    dumped_data = original_history.dump()
    print("Dumped data:")
    print(dumped_data)

    loaded_history = ChatHistory()
    loaded_history.load(dumped_data)

    # Print detailed information about the loaded history
    print("\nLoaded history details:")
    for i, message in enumerate(loaded_history.history):
        print(f"\nMessage {i + 1}:")
        print(f"Role: {message.role}")
        print(f"Turn ID: {message.turn_id}")
        print(f"Content type: {type(message.content).__name__}")
        print("Content:")
        for field, value in message.content.model_dump().items():
            print(f"  {field}: {value}")

    # Final verification
    print("\nFinal verification:")
    print(f"Max messages: {loaded_history.max_messages}")
    print(f"Current turn ID: {loaded_history.get_current_turn_id()}")
    print("Last message content:")
    last_message = loaded_history.history[-1]
    print(last_message.content.model_dump())