Source code for atomic_agents.agents.atomic_agent

import instructor
from pydantic import BaseModel, Field
from typing import Optional, Type, Generator, AsyncGenerator, get_args, Dict, List, Callable
import logging
from atomic_agents.context.chat_history import ChatHistory
from atomic_agents.context.system_prompt_generator import (
    BaseDynamicContextProvider,
    SystemPromptGenerator,
)
from atomic_agents.base.base_io_schema import BaseIOSchema

from instructor.dsl.partial import PartialBase
from jiter import from_json


def model_from_chunks_patched(cls, json_chunks, **kwargs):
    """
    Incrementally validate a streamed JSON document against the partial model.

    Every incoming chunk is appended to a running text buffer; after each append
    the whole buffer is re-parsed with jiter in ``"trailing-strings"`` partial
    mode and validated with the class's partial model.

    Args:
        cls: Class exposing ``get_partial_model()`` (bound here because the
            function is installed as a classmethod on ``PartialBase``).
        json_chunks: Iterable of string fragments of the JSON payload.
        **kwargs: Forwarded unchanged to ``model_validate``.

    Yields:
        One validated partial-model instance per received chunk.
    """
    partial_cls = cls.get_partial_model()
    buffer = ""
    for piece in json_chunks:
        buffer += piece
        # An empty buffer is parsed as an empty object rather than empty input.
        raw = (buffer or "{}").encode()
        parsed = from_json(raw, partial_mode="trailing-strings")
        yield partial_cls.model_validate(parsed, strict=None, **kwargs)
async def model_from_chunks_async_patched(cls, json_chunks, **kwargs):
    """
    Async counterpart of the chunk-by-chunk partial-model validator.

    Every chunk received from the async iterable is appended to a running text
    buffer; after each append the whole buffer is re-parsed with jiter in
    ``"trailing-strings"`` partial mode and validated with the class's partial
    model.

    Args:
        cls: Class exposing ``get_partial_model()`` (bound here because the
            function is installed as a classmethod on ``PartialBase``).
        json_chunks: Async iterable of string fragments of the JSON payload.
        **kwargs: Forwarded unchanged to ``model_validate``.

    Yields:
        One validated partial-model instance per received chunk.
    """
    partial_cls = cls.get_partial_model()
    buffer = ""
    async for piece in json_chunks:
        buffer += piece
        # An empty buffer is parsed as an empty object rather than empty input.
        raw = (buffer or "{}").encode()
        parsed = from_json(raw, partial_mode="trailing-strings")
        yield partial_cls.model_validate(parsed, strict=None, **kwargs)
# Monkey-patch instructor's PartialBase so that both the sync and async
# chunk-to-model converters use the implementations defined in this module.
# This is a module-level side effect: it takes effect as soon as the module
# is imported and applies to PartialBase itself.
PartialBase.model_from_chunks = classmethod(model_from_chunks_patched)
PartialBase.model_from_chunks_async = classmethod(model_from_chunks_async_patched)
class BasicChatInputSchema(BaseIOSchema):
    """This schema represents the input from the user to the AI agent."""

    # Default input schema: AtomicAgent.input_schema falls back to this class
    # when the agent was not created through a parametrised alias.
    chat_message: str = Field(..., description="The chat message sent by the user to the assistant.")
class BasicChatOutputSchema(BaseIOSchema):
    """This schema represents the response generated by the chat agent."""

    # Default output schema: AtomicAgent.output_schema falls back to this class
    # when the agent was not created through a parametrised alias.
    chat_message: str = Field(
        ...,
        description=(
            "The chat message exchanged between the user and the chat agent. "
            "This contains the markdown-enabled response generated by the chat agent."
        ),
    )
class AgentConfig(BaseModel):
    # Settings bundle read by AtomicAgent.__init__. Documented with comments
    # rather than a docstring so the generated JSON schema stays unchanged.

    # The instructor client (and the other helper components) are not pydantic
    # models, so arbitrary types have to be permitted explicitly.
    model_config = {"arbitrary_types_allowed": True}

    # Required: the instructor-wrapped LLM client.
    client: instructor.client.Instructor = Field(
        ...,
        description="Client for interacting with the language model.",
    )

    # Model identifier sent with every completion request.
    model: str = Field(
        default="gpt-4o-mini",
        description="The model to use for generating responses.",
    )

    # Optional pre-existing history; the agent creates a fresh one when None.
    history: Optional[ChatHistory] = Field(
        default=None,
        description="History component for storing chat history.",
    )

    # Optional prompt generator; the agent creates a default one when None.
    system_prompt_generator: Optional[SystemPromptGenerator] = Field(
        default=None,
        description="Component for generating system prompts.",
    )

    # Role name used for the system message; None suppresses that message.
    system_role: Optional[str] = Field(
        default="system",
        description="The role of the system in the conversation. None means no system prompt.",
    )

    # Extra keyword arguments splatted into each completion call.
    model_api_parameters: Optional[dict] = Field(
        default=None,
        description="Additional parameters passed to the API provider.",
    )
[docs] class AtomicAgent[InputSchema: BaseIOSchema, OutputSchema: BaseIOSchema]: """ Base class for chat agents with full Instructor hook system integration. This class provides the core functionality for handling chat interactions, including managing history, generating system prompts, and obtaining responses from a language model. It includes comprehensive hook system support for monitoring and error handling. Type Parameters: InputSchema: Schema for the user input, must be a subclass of BaseIOSchema. OutputSchema: Schema for the agent's output, must be a subclass of BaseIOSchema. Attributes: client: Client for interacting with the language model. model (str): The model to use for generating responses. history (ChatHistory): History component for storing chat history. system_prompt_generator (SystemPromptGenerator): Component for generating system prompts. system_role (Optional[str]): The role of the system in the conversation. None means no system prompt. initial_history (ChatHistory): Initial state of the history. current_user_input (Optional[InputSchema]): The current user input being processed. model_api_parameters (dict): Additional parameters passed to the API provider. - Use this for parameters like 'temperature', 'max_tokens', etc. Hook System: The AtomicAgent integrates with Instructor's hook system to provide comprehensive monitoring and error handling capabilities. 
Supported events include: - 'parse:error': Triggered when Pydantic validation fails - 'completion:kwargs': Triggered before completion request - 'completion:response': Triggered after completion response - 'completion:error': Triggered on completion errors - 'completion:last_attempt': Triggered on final retry attempt Hook Methods: - register_hook(event, handler): Register a hook handler for an event - unregister_hook(event, handler): Remove a hook handler - clear_hooks(event=None): Clear hooks for specific event or all events - enable_hooks()/disable_hooks(): Control hook processing - hooks_enabled: Property to check if hooks are enabled Example: ```python # Basic usage agent = AtomicAgent[InputSchema, OutputSchema](config) # Register parse error hook for intelligent retry handling def handle_parse_error(error): print(f"Validation failed: {error}") # Implement custom retry logic, logging, etc. agent.register_hook("parse:error", handle_parse_error) # Now parse:error hooks will fire on validation failures response = agent.run(user_input) ``` """
[docs] def __init__(self, config: AgentConfig): """ Initializes the AtomicAgent. Args: config (AgentConfig): Configuration for the chat agent. """ self.client = config.client self.model = config.model self.history = config.history or ChatHistory() self.system_prompt_generator = config.system_prompt_generator or SystemPromptGenerator() self.system_role = config.system_role self.initial_history = self.history.copy() self.current_user_input = None self.model_api_parameters = config.model_api_parameters or {} # Hook management attributes self._hook_handlers: Dict[str, List[Callable]] = {} self._hooks_enabled: bool = True
[docs] def reset_history(self): """ Resets the history to its initial state. """ self.history = self.initial_history.copy()
@property def input_schema(self) -> Type[BaseIOSchema]: if hasattr(self, "__orig_class__"): TI, _ = get_args(self.__orig_class__) else: TI = BasicChatInputSchema return TI @property def output_schema(self) -> Type[BaseIOSchema]: if hasattr(self, "__orig_class__"): _, TO = get_args(self.__orig_class__) else: TO = BasicChatOutputSchema return TO def _prepare_messages(self): if self.system_role is None: self.messages = [] else: self.messages = [ { "role": self.system_role, "content": self.system_prompt_generator.generate_prompt(), } ] self.messages += self.history.get_history()
[docs] def run(self, user_input: Optional[InputSchema] = None) -> OutputSchema: """ Runs the chat agent with the given user input synchronously. Args: user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history. Returns: OutputSchema: The response from the chat agent. """ assert not isinstance( self.client, instructor.client.AsyncInstructor ), "The run method is not supported for async clients. Use run_async instead." if user_input: self.history.initialize_turn() self.current_user_input = user_input self.history.add_message("user", user_input) self._prepare_messages() response = self.client.chat.completions.create( messages=self.messages, model=self.model, response_model=self.output_schema, **self.model_api_parameters, ) self.history.add_message("assistant", response) return response
[docs] def run_stream(self, user_input: Optional[InputSchema] = None) -> Generator[OutputSchema, None, OutputSchema]: """ Runs the chat agent with the given user input, supporting streaming output. Args: user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history. Yields: OutputSchema: Partial responses from the chat agent. Returns: OutputSchema: The final response from the chat agent. """ assert not isinstance( self.client, instructor.client.AsyncInstructor ), "The run_stream method is not supported for async clients. Use run_async instead." if user_input: self.history.initialize_turn() self.current_user_input = user_input self.history.add_message("user", user_input) self._prepare_messages() response_stream = self.client.chat.completions.create_partial( model=self.model, messages=self.messages, response_model=self.output_schema, **self.model_api_parameters, stream=True, ) for partial_response in response_stream: yield partial_response full_response_content = self.output_schema(**partial_response.model_dump()) self.history.add_message("assistant", full_response_content) return full_response_content
[docs] async def run_async(self, user_input: Optional[InputSchema] = None) -> OutputSchema: """ Runs the chat agent asynchronously with the given user input. Args: user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history. Returns: OutputSchema: The response from the chat agent. Raises: NotAsyncIterableError: If used as an async generator (in an async for loop). Use run_async_stream() method instead for streaming responses. """ assert isinstance(self.client, instructor.client.AsyncInstructor), "The run_async method is for async clients." if user_input: self.history.initialize_turn() self.current_user_input = user_input self.history.add_message("user", user_input) self._prepare_messages() response = await self.client.chat.completions.create( model=self.model, messages=self.messages, response_model=self.output_schema, **self.model_api_parameters ) self.history.add_message("assistant", response) return response
[docs] async def run_async_stream(self, user_input: Optional[InputSchema] = None) -> AsyncGenerator[OutputSchema, None]: """ Runs the chat agent asynchronously with the given user input, supporting streaming output. Args: user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history. Yields: OutputSchema: Partial responses from the chat agent. """ assert isinstance(self.client, instructor.client.AsyncInstructor), "The run_async method is for async clients." if user_input: self.history.initialize_turn() self.current_user_input = user_input self.history.add_message("user", user_input) self._prepare_messages() response_stream = self.client.chat.completions.create_partial( model=self.model, messages=self.messages, response_model=self.output_schema, **self.model_api_parameters, stream=True, ) last_response = None async for partial_response in response_stream: last_response = partial_response yield partial_response if last_response: full_response_content = self.output_schema(**last_response.model_dump()) self.history.add_message("assistant", full_response_content)
[docs] def get_context_provider(self, provider_name: str) -> Type[BaseDynamicContextProvider]: """ Retrieves a context provider by name. Args: provider_name (str): The name of the context provider. Returns: BaseDynamicContextProvider: The context provider if found. Raises: KeyError: If the context provider is not found. """ if provider_name not in self.system_prompt_generator.context_providers: raise KeyError(f"Context provider '{provider_name}' not found.") return self.system_prompt_generator.context_providers[provider_name]
[docs] def register_context_provider(self, provider_name: str, provider: BaseDynamicContextProvider): """ Registers a new context provider. Args: provider_name (str): The name of the context provider. provider (BaseDynamicContextProvider): The context provider instance. """ self.system_prompt_generator.context_providers[provider_name] = provider
[docs] def unregister_context_provider(self, provider_name: str): """ Unregisters an existing context provider. Args: provider_name (str): The name of the context provider to remove. """ if provider_name in self.system_prompt_generator.context_providers: del self.system_prompt_generator.context_providers[provider_name] else: raise KeyError(f"Context provider '{provider_name}' not found.")
# Hook Management Methods
[docs] def register_hook(self, event: str, handler: Callable) -> None: """ Registers a hook handler for a specific event. Args: event (str): The event name (e.g., 'parse:error', 'completion:kwargs', etc.) handler (Callable): The callback function to handle the event """ if event not in self._hook_handlers: self._hook_handlers[event] = [] self._hook_handlers[event].append(handler) # Register with instructor client if it supports hooks if hasattr(self.client, "on"): self.client.on(event, handler)
[docs] def unregister_hook(self, event: str, handler: Callable) -> None: """ Unregisters a hook handler for a specific event. Args: event (str): The event name handler (Callable): The callback function to remove """ if event in self._hook_handlers and handler in self._hook_handlers[event]: self._hook_handlers[event].remove(handler) # Remove from instructor client if it supports hooks if hasattr(self.client, "off"): self.client.off(event, handler)
[docs] def clear_hooks(self, event: Optional[str] = None) -> None: """ Clears hook handlers for a specific event or all events. Args: event (Optional[str]): The event name to clear, or None to clear all """ if event: if event in self._hook_handlers: # Clear from instructor client first if hasattr(self.client, "clear"): self.client.clear(event) self._hook_handlers[event].clear() else: # Clear all hooks if hasattr(self.client, "clear"): self.client.clear() self._hook_handlers.clear()
def _dispatch_hook(self, event: str, *args, **kwargs) -> None: """ Internal method to dispatch hook events with error isolation. Args: event (str): The event name *args: Arguments to pass to handlers **kwargs: Keyword arguments to pass to handlers """ if not self._hooks_enabled or event not in self._hook_handlers: return for handler in self._hook_handlers[event]: try: handler(*args, **kwargs) except Exception as e: # Log error but don't interrupt main flow logger = logging.getLogger(__name__) logger.warning(f"Hook handler for '{event}' raised exception: {e}")
[docs] def enable_hooks(self) -> None: """Enable hook processing.""" self._hooks_enabled = True
[docs] def disable_hooks(self) -> None: """Disable hook processing.""" self._hooks_enabled = False
@property def hooks_enabled(self) -> bool: """Check if hooks are enabled.""" return self._hooks_enabled
if __name__ == "__main__":
    # Interactive demo: prints the agent's schemas/configuration, then starts
    # a terminal chat against OpenAI using the default chat schemas.
    import asyncio
    import json

    import instructor
    from openai import OpenAI, AsyncOpenAI
    from rich import box
    from rich.console import Console
    from rich.live import Live
    from rich.panel import Panel
    from rich.syntax import Syntax
    from rich.table import Table

    def _create_schema_table(title: str, schema: Type[BaseModel]) -> Table:
        """Build a rich table listing every field of a schema.

        Args:
            title (str): Caption shown above the table
            schema (Type[BaseModel]): Schema whose fields are listed

        Returns:
            Table: One row per field with its name, annotation and description
        """
        table = Table(title=title, box=box.ROUNDED)
        for heading, colour in (("Field", "cyan"), ("Type", "magenta"), ("Description", "green")):
            table.add_column(heading, style=colour)

        for name, info in schema.model_fields.items():
            table.add_row(name, str(info.annotation), info.description or "")
        return table

    def _create_config_table(agent: AtomicAgent) -> Table:
        """Build a rich table summarising the agent's configuration.

        Args:
            agent (AtomicAgent): Agent to describe

        Returns:
            Table: Rows for the model name and the history / prompt-generator class names
        """
        table = Table(title="Agent Configuration", box=box.ROUNDED)
        for heading, colour in (("Property", "cyan"), ("Value", "yellow")):
            table.add_column(heading, style=colour)

        table.add_row("Model", agent.model)
        table.add_row("History", str(type(agent.history).__name__))
        table.add_row("System Prompt Generator", str(type(agent.system_prompt_generator).__name__))
        return table

    def display_agent_info(agent: AtomicAgent):
        """Print the agent's schemas, configuration and a sample system prompt."""
        console = Console()
        heading = Panel.fit(
            "[bold blue]Agent Information[/bold blue]",
            border_style="blue",
            padding=(1, 1),
        )
        console.print(heading)

        # Input schema, then output schema
        console.print(_create_schema_table("Input Schema", agent.input_schema))
        console.print(_create_schema_table("Output Schema", agent.output_schema))

        # Configuration summary
        console.print(_create_config_table(agent))

        # System prompt as currently generated
        prompt_text = agent.system_prompt_generator.generate_prompt()
        prompt_panel = Panel(
            Syntax(prompt_text, "markdown", theme="monokai", line_numbers=True),
            title="Sample System Prompt",
            border_style="green",
            expand=False,
        )
        console.print(prompt_panel)

    async def chat_loop(streaming: bool = False):
        """Run an interactive terminal chat with the agent until the user types 'exit'.

        Args:
            streaming (bool): When True an async OpenAI client is used and partial
                responses are rendered live; otherwise each reply is printed once complete
        """
        # Only the underlying OpenAI client differs between the two modes.
        backend = AsyncOpenAI() if streaming else OpenAI()
        config = AgentConfig(client=instructor.from_openai(backend), model="gpt-4o-mini")
        agent = AtomicAgent[BasicChatInputSchema, BasicChatOutputSchema](config)

        # Show the agent details before the conversation starts
        display_agent_info(agent)

        console = Console()
        banner = (
            "[bold blue]Interactive Chat Mode[/bold blue]\n"
            f"[cyan]Streaming: {streaming}[/cyan]\n"
            "Type 'exit' to quit"
        )
        console.print(Panel.fit(banner, border_style="blue", padding=(1, 1)))

        while True:
            typed = console.input("\n[bold green]You:[/bold green] ")
            if typed.lower() == "exit":
                console.print("[yellow]Goodbye![/yellow]")
                break

            user_input = agent.input_schema(chat_message=typed)
            console.print("[bold blue]Assistant:[/bold blue]")

            if not streaming:
                reply = agent.run(user_input)
                console.print(json.dumps(reply.model_dump(), indent=2))
                continue

            with Live(console=console, refresh_per_second=4) as live:
                # run_async_stream (not run_async) is the streaming entry point
                async for partial in agent.run_async_stream(user_input):
                    live.update(json.dumps(partial.model_dump(), indent=2))

    console = Console()
    console.print("\n[bold]Starting chat loop...[/bold]")
    asyncio.run(chat_loop(streaming=True))