Source code for atomic_agents.agents.atomic_agent

import instructor
from pydantic import BaseModel, Field
from typing import Optional, Type, Generator, AsyncGenerator, get_args
from atomic_agents.context.chat_history import ChatHistory
from atomic_agents.context.system_prompt_generator import (
    BaseDynamicContextProvider,
    SystemPromptGenerator,
)
from atomic_agents.base.base_io_schema import BaseIOSchema

from instructor.dsl.partial import PartialBase
from jiter import from_json


def model_from_chunks_patched(cls, json_chunks, **kwargs):
    """
    Lazily turn a stream of JSON text fragments into partial model instances.

    Every fragment is appended to a running buffer; after each append the
    whole buffer is re-parsed and validated against the class's partial model,
    so one object is yielded per incoming fragment.

    Args:
        cls: The class this function is bound to as a classmethod; it must
            provide ``get_partial_model()``.
        json_chunks: Iterable of string fragments of a JSON document.
        **kwargs: Forwarded unchanged to ``model_validate``.

    Yields:
        The validated partial model built from everything received so far.
    """
    partial_model = cls.get_partial_model()
    buffer = ""
    for fragment in json_chunks:
        buffer += fragment
        # An empty buffer is not parseable JSON, so fall back to an empty object.
        raw = (buffer or "{}").encode()
        parsed = from_json(raw, partial_mode="trailing-strings")
        yield partial_model.model_validate(parsed, strict=None, **kwargs)
async def model_from_chunks_async_patched(cls, json_chunks, **kwargs):
    """
    Asynchronously turn a stream of JSON text fragments into partial models.

    Every fragment is appended to a running buffer; after each append the
    whole buffer is re-parsed and validated against the class's partial model,
    so one object is yielded per incoming fragment.

    Args:
        cls: The class this function is bound to as a classmethod; it must
            provide ``get_partial_model()``.
        json_chunks: Async iterable of string fragments of a JSON document.
        **kwargs: Forwarded unchanged to ``model_validate``.

    Yields:
        The validated partial model built from everything received so far.
    """
    partial_model = cls.get_partial_model()
    buffer = ""
    async for fragment in json_chunks:
        buffer += fragment
        # An empty buffer is not parseable JSON, so fall back to an empty object.
        raw = (buffer or "{}").encode()
        parsed = from_json(raw, partial_mode="trailing-strings")
        yield partial_model.model_validate(parsed, strict=None, **kwargs)
# Replace instructor's chunk-to-model hooks (sync and async) on PartialBase with
# the patched generators from this module. This runs at import time and affects
# every class deriving from PartialBase.
PartialBase.model_from_chunks, PartialBase.model_from_chunks_async = (
    classmethod(model_from_chunks_patched),
    classmethod(model_from_chunks_async_patched),
)
class BasicChatInputSchema(BaseIOSchema):
    """This schema represents the input from the user to the AI agent."""

    # NOTE: the class docstring and the field description are kept verbatim;
    # pydantic exposes both in the generated schema, so they are runtime text.

    # Required: the user's message for the current turn.
    chat_message: str = Field(..., description="The chat message sent by the user to the assistant.")
class BasicChatOutputSchema(BaseIOSchema):
    """This schema represents the response generated by the chat agent."""

    # NOTE: the class docstring and the field description are kept verbatim;
    # pydantic exposes both in the generated schema, so they are runtime text.

    # Required: the agent's reply for the current turn.
    chat_message: str = Field(
        ...,
        description="The chat message exchanged between the user and the chat agent. This contains the markdown-enabled response generated by the chat agent.",
    )
class AgentConfig(BaseModel):
    # Settings bundle consumed by AtomicAgent.__init__. Documented with comments
    # rather than a docstring so the model's generated schema stays unchanged.

    # Needed because the client / history / prompt-generator fields hold
    # non-pydantic objects.
    model_config = {"arbitrary_types_allowed": True}

    # Required: the instructor client used to call the language model.
    client: instructor.client.Instructor = Field(
        ...,
        description="Client for interacting with the language model.",
    )

    # Model identifier forwarded to the client on every request.
    model: str = Field(
        default="gpt-4o-mini",
        description="The model to use for generating responses.",
    )

    # Optional pre-existing history; the agent creates a fresh one when None.
    history: Optional[ChatHistory] = Field(
        default=None,
        description="History component for storing chat history.",
    )

    # Optional prompt generator; the agent creates a default one when None.
    system_prompt_generator: Optional[SystemPromptGenerator] = Field(
        default=None,
        description="Component for generating system prompts.",
    )

    # Role name of the leading message; None disables the system prompt.
    system_role: Optional[str] = Field(
        default="system",
        description="The role of the system in the conversation. None means no system prompt.",
    )

    # Extra keyword arguments spread into each completion call.
    model_api_parameters: Optional[dict] = Field(
        default=None,
        description="Additional parameters passed to the API provider.",
    )
[docs] class AtomicAgent[InputSchema: BaseIOSchema, OutputSchema: BaseIOSchema]: """ Base class for chat agents. This class provides the core functionality for handling chat interactions, including managing history, generating system prompts, and obtaining responses from a language model. Type Parameters: InputSchema: Schema for the user input, must be a subclass of BaseIOSchema. OutputSchema: Schema for the agent's output, must be a subclass of BaseIOSchema. Attributes: client: Client for interacting with the language model. model (str): The model to use for generating responses. history (ChatHistory): History component for storing chat history. system_prompt_generator (SystemPromptGenerator): Component for generating system prompts. system_role (Optional[str]): The role of the system in the conversation. None means no system prompt. initial_history (ChatHistory): Initial state of the history. current_user_input (Optional[InputSchema]): The current user input being processed. model_api_parameters (dict): Additional parameters passed to the API provider. - Use this for parameters like 'temperature', 'max_tokens', etc. """
[docs] def __init__(self, config: AgentConfig): """ Initializes the AtomicAgent. Args: config (AgentConfig): Configuration for the chat agent. """ self.client = config.client self.model = config.model self.history = config.history or ChatHistory() self.system_prompt_generator = config.system_prompt_generator or SystemPromptGenerator() self.system_role = config.system_role self.initial_history = self.history.copy() self.current_user_input = None self.model_api_parameters = config.model_api_parameters or {}
[docs] def reset_history(self): """ Resets the history to its initial state. """ self.history = self.initial_history.copy()
@property def input_schema(self) -> Type[BaseIOSchema]: if hasattr(self, "__orig_class__"): TI, _ = get_args(self.__orig_class__) else: TI = BasicChatInputSchema return TI @property def output_schema(self) -> Type[BaseIOSchema]: if hasattr(self, "__orig_class__"): _, TO = get_args(self.__orig_class__) else: TO = BasicChatOutputSchema return TO def _prepare_messages(self): if self.system_role is None: self.messages = [] else: self.messages = [ { "role": self.system_role, "content": self.system_prompt_generator.generate_prompt(), } ] self.messages += self.history.get_history()
[docs] def run(self, user_input: Optional[InputSchema] = None) -> OutputSchema: """ Runs the chat agent with the given user input synchronously. Args: user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history. Returns: OutputSchema: The response from the chat agent. """ assert not isinstance( self.client, instructor.client.AsyncInstructor ), "The run method is not supported for async clients. Use run_async instead." if user_input: self.history.initialize_turn() self.current_user_input = user_input self.history.add_message("user", user_input) self._prepare_messages() response = self.client.chat.completions.create( messages=self.messages, model=self.model, response_model=self.output_schema, **self.model_api_parameters, ) self.history.add_message("assistant", response) return response
[docs] def run_stream(self, user_input: Optional[InputSchema] = None) -> Generator[OutputSchema, None, OutputSchema]: """ Runs the chat agent with the given user input, supporting streaming output. Args: user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history. Yields: OutputSchema: Partial responses from the chat agent. Returns: OutputSchema: The final response from the chat agent. """ assert not isinstance( self.client, instructor.client.AsyncInstructor ), "The run_stream method is not supported for async clients. Use run_async instead." if user_input: self.history.initialize_turn() self.current_user_input = user_input self.history.add_message("user", user_input) self._prepare_messages() response_stream = self.client.chat.completions.create_partial( model=self.model, messages=self.messages, response_model=self.output_schema, **self.model_api_parameters, stream=True, ) for partial_response in response_stream: yield partial_response full_response_content = self.output_schema(**partial_response.model_dump()) self.history.add_message("assistant", full_response_content) return full_response_content
[docs] async def run_async(self, user_input: Optional[InputSchema] = None) -> OutputSchema: """ Runs the chat agent asynchronously with the given user input. Args: user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history. Returns: OutputSchema: The response from the chat agent. Raises: NotAsyncIterableError: If used as an async generator (in an async for loop). Use run_async_stream() method instead for streaming responses. """ assert isinstance(self.client, instructor.client.AsyncInstructor), "The run_async method is for async clients." if user_input: self.history.initialize_turn() self.current_user_input = user_input self.history.add_message("user", user_input) self._prepare_messages() response = await self.client.chat.completions.create( model=self.model, messages=self.messages, response_model=self.output_schema, **self.model_api_parameters ) self.history.add_message("assistant", response) return response
[docs] async def run_async_stream(self, user_input: Optional[InputSchema] = None) -> AsyncGenerator[OutputSchema, None]: """ Runs the chat agent asynchronously with the given user input, supporting streaming output. Args: user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history. Yields: OutputSchema: Partial responses from the chat agent. """ assert isinstance(self.client, instructor.client.AsyncInstructor), "The run_async method is for async clients." if user_input: self.history.initialize_turn() self.current_user_input = user_input self.history.add_message("user", user_input) self._prepare_messages() response_stream = self.client.chat.completions.create_partial( model=self.model, messages=self.messages, response_model=self.output_schema, **self.model_api_parameters, stream=True, ) last_response = None async for partial_response in response_stream: last_response = partial_response yield partial_response if last_response: full_response_content = self.output_schema(**last_response.model_dump()) self.history.add_message("assistant", full_response_content)
[docs] def get_context_provider(self, provider_name: str) -> Type[BaseDynamicContextProvider]: """ Retrieves a context provider by name. Args: provider_name (str): The name of the context provider. Returns: BaseDynamicContextProvider: The context provider if found. Raises: KeyError: If the context provider is not found. """ if provider_name not in self.system_prompt_generator.context_providers: raise KeyError(f"Context provider '{provider_name}' not found.") return self.system_prompt_generator.context_providers[provider_name]
[docs] def register_context_provider(self, provider_name: str, provider: BaseDynamicContextProvider): """ Registers a new context provider. Args: provider_name (str): The name of the context provider. provider (BaseDynamicContextProvider): The context provider instance. """ self.system_prompt_generator.context_providers[provider_name] = provider
[docs] def unregister_context_provider(self, provider_name: str): """ Unregisters an existing context provider. Args: provider_name (str): The name of the context provider to remove. """ if provider_name in self.system_prompt_generator.context_providers: del self.system_prompt_generator.context_providers[provider_name] else: raise KeyError(f"Context provider '{provider_name}' not found.")
if __name__ == "__main__":
    from rich.console import Console
    from rich.panel import Panel
    from rich.table import Table
    from rich.syntax import Syntax
    from rich import box
    from openai import OpenAI, AsyncOpenAI
    import instructor
    import asyncio
    from rich.live import Live
    import json

    def _create_schema_table(title: str, schema: Type[BaseModel]) -> Table:
        """Build a rich table with one row per field of ``schema``.

        Args:
            title (str): Caption shown above the table.
            schema (Type[BaseModel]): Model class whose fields are listed.

        Returns:
            Table: Table with field name, annotation and description columns.
        """
        table = Table(title=title, box=box.ROUNDED)
        for heading, colour in (("Field", "cyan"), ("Type", "magenta"), ("Description", "green")):
            table.add_column(heading, style=colour)

        for name, info in schema.model_fields.items():
            table.add_row(name, str(info.annotation), info.description or "")

        return table

    def _create_config_table(agent: AtomicAgent) -> Table:
        """Build a rich table summarising the agent's model and component types.

        Args:
            agent (AtomicAgent): Agent to describe.

        Returns:
            Table: Two-column property/value table.
        """
        table = Table(title="Agent Configuration", box=box.ROUNDED)
        table.add_column("Property", style="cyan")
        table.add_column("Value", style="yellow")

        rows = (
            ("Model", agent.model),
            ("History", type(agent.history).__name__),
            ("System Prompt Generator", type(agent.system_prompt_generator).__name__),
        )
        for label, value in rows:
            table.add_row(label, value)

        return table

    def display_agent_info(agent: AtomicAgent):
        """Print the agent's schemas, configuration and generated system prompt."""
        console = Console()
        banner = Panel.fit(
            "[bold blue]Agent Information[/bold blue]",
            border_style="blue",
            padding=(1, 1),
        )
        console.print(banner)

        # Schemas first, then configuration.
        console.print(_create_schema_table("Input Schema", agent.input_schema))
        console.print(_create_schema_table("Output Schema", agent.output_schema))
        console.print(_create_config_table(agent))

        # Finally the system prompt the agent would currently send.
        prompt_text = agent.system_prompt_generator.generate_prompt()
        highlighted = Syntax(prompt_text, "markdown", theme="monokai", line_numbers=True)
        console.print(
            Panel(
                highlighted,
                title="Sample System Prompt",
                border_style="green",
                expand=False,
            )
        )

    async def chat_loop(streaming: bool = False):
        """Run an interactive console chat until the user types 'exit'.

        Args:
            streaming (bool): When True, use an async client and render partial
                responses live; otherwise use a sync client and print each reply once.
        """
        # Only the underlying OpenAI client differs between the two modes.
        openai_client = AsyncOpenAI() if streaming else OpenAI()
        client = instructor.from_openai(openai_client)
        config = AgentConfig(client=client, model="gpt-4o-mini")
        agent = AtomicAgent[BasicChatInputSchema, BasicChatOutputSchema](config)

        # Show what the agent looks like before any messages are exchanged.
        display_agent_info(agent)

        console = Console()
        console.print(
            Panel.fit(
                "[bold blue]Interactive Chat Mode[/bold blue]\n"
                f"[cyan]Streaming: {streaming}[/cyan]\n"
                "Type 'exit' to quit",
                border_style="blue",
                padding=(1, 1),
            )
        )

        while True:
            typed = console.input("\n[bold green]You:[/bold green] ")

            if typed.lower() == "exit":
                console.print("[yellow]Goodbye![/yellow]")
                break

            user_input = agent.input_schema(chat_message=typed)
            console.print("[bold blue]Assistant:[/bold blue]")

            if not streaming:
                reply = agent.run(user_input)
                console.print(json.dumps(reply.model_dump(), indent=2))
                continue

            with Live(console=console, refresh_per_second=4) as live:
                # Streaming goes through run_async_stream, which yields partials.
                async for partial in agent.run_async_stream(user_input):
                    live.update(json.dumps(partial.model_dump(), indent=2))

    console = Console()
    console.print("\n[bold]Starting chat loop...[/bold]")
    asyncio.run(chat_loop(streaming=True))