import instructor
from pydantic import BaseModel, Field
from typing import Optional, Type, Generator, AsyncGenerator, get_args
from atomic_agents.context.chat_history import ChatHistory
from atomic_agents.context.system_prompt_generator import (
BaseDynamicContextProvider,
SystemPromptGenerator,
)
from atomic_agents.base.base_io_schema import BaseIOSchema
from instructor.dsl.partial import PartialBase
from jiter import from_json
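
# Patch instructor's PartialBase so that streamed JSON chunks are parsed with
# jiter's "trailing-strings" partial mode. This lets incomplete string values
# (e.g. a half-generated chat_message) surface in partial models while the LLM
# is still streaming, instead of waiting for each string to be closed.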
def model_from_chunks_patched(cls, json_chunks, **kwargs):
    potential_object = ""
    partial_model = cls.get_partial_model()
    for chunk in json_chunks:
        potential_object += chunk
        obj = from_json((potential_object or "{}").encode(), partial_mode="trailing-strings")
        obj = partial_model.model_validate(obj, strict=None, **kwargs)
        yield obj


async def model_from_chunks_async_patched(cls, json_chunks, **kwargs):
    potential_object = ""
    partial_model = cls.get_partial_model()
    async for chunk in json_chunks:
        potential_object += chunk
        obj = from_json((potential_object or "{}").encode(), partial_mode="trailing-strings")
        obj = partial_model.model_validate(obj, strict=None, **kwargs)
        yield obj
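
# Install the patched parsers on PartialBase so instructor's create_partial
# streaming paths pick them up for both sync and async clients.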
PartialBase.model_from_chunks = classmethod(model_from_chunks_patched)
PartialBase.model_from_chunks_async = classmethod(model_from_chunks_async_patched)


class BasicChatInputSchema(BaseIOSchema):
    """This schema represents the input from the user to the chat agent."""

    chat_message: str = Field(
        ...,
        description="The chat message sent from the user to the chat agent.",
    )


class BasicChatOutputSchema(BaseIOSchema):
    """This schema represents the response generated by the chat agent."""

    chat_message: str = Field(
        ...,
        description=(
            "The chat message exchanged between the user and the chat agent. "
            "This contains the markdown-enabled response generated by the chat agent."
        ),
    )


class AgentConfig(BaseModel):
    client: instructor.client.Instructor = Field(..., description="Client for interacting with the language model.")
    model: str = Field(default="gpt-4o-mini", description="The model to use for generating responses.")
    history: Optional[ChatHistory] = Field(default=None, description="History component for storing chat history.")
    system_prompt_generator: Optional[SystemPromptGenerator] = Field(
        default=None, description="Component for generating system prompts."
    )
    system_role: Optional[str] = Field(
        default="system", description="The role of the system in the conversation. None means no system prompt."
    )
    model_api_parameters: Optional[dict] = Field(default=None, description="Additional parameters passed to the API provider.")

    # arbitrary_types_allowed lets pydantic accept the non-pydantic Instructor client type.
    model_config = {"arbitrary_types_allowed": True}


class AtomicAgent[InputSchema: BaseIOSchema, OutputSchema: BaseIOSchema]:
    """
    Base class for chat agents.

    This class provides the core functionality for handling chat interactions, including managing history,
    generating system prompts, and obtaining responses from a language model.

    Type Parameters:
        InputSchema: Schema for the user input, must be a subclass of BaseIOSchema.
        OutputSchema: Schema for the agent's output, must be a subclass of BaseIOSchema.

    Attributes:
        client: Client for interacting with the language model.
        model (str): The model to use for generating responses.
        history (ChatHistory): History component for storing chat history.
        system_prompt_generator (SystemPromptGenerator): Component for generating system prompts.
        system_role (Optional[str]): The role of the system in the conversation. None means no system prompt.
        initial_history (ChatHistory): Initial state of the history.
        current_user_input (Optional[InputSchema]): The current user input being processed.
        model_api_parameters (dict): Additional parameters passed to the API provider.
            Use this for parameters like 'temperature', 'max_tokens', etc.
    """

    def __init__(self, config: AgentConfig):
        """
        Initializes the AtomicAgent.

        Args:
            config (AgentConfig): Configuration for the chat agent.
        """
        self.client = config.client
        self.model = config.model
        self.history = config.history or ChatHistory()
        self.system_prompt_generator = config.system_prompt_generator or SystemPromptGenerator()
        self.system_role = config.system_role
        self.initial_history = self.history.copy()
        self.current_user_input = None
        self.model_api_parameters = config.model_api_parameters or {}

    def reset_history(self):
        """
        Resets the history to its initial state.
        """
        self.history = self.initial_history.copy()
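
    # When the agent is parameterized, e.g. AtomicAgent[MyInput, MyOutput](config),
    # __orig_class__ carries the concrete type arguments; without parameterization
    # we fall back to the basic chat schemas.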
    @property
    def input_schema(self) -> Type[BaseIOSchema]:
        if hasattr(self, "__orig_class__"):
            TI, _ = get_args(self.__orig_class__)
        else:
            TI = BasicChatInputSchema
        return TI

    @property
    def output_schema(self) -> Type[BaseIOSchema]:
        if hasattr(self, "__orig_class__"):
            _, TO = get_args(self.__orig_class__)
        else:
            TO = BasicChatOutputSchema
        return TO

    def _prepare_messages(self):
        if self.system_role is None:
            self.messages = []
        else:
            self.messages = [
                {
                    "role": self.system_role,
                    "content": self.system_prompt_generator.generate_prompt(),
                }
            ]
        self.messages += self.history.get_history()
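
    # Each run* entry point follows the same flow: record the user turn, build the
    # message list (system prompt + history), call the model with response_model
    # validation, then record the assistant's reply in history.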
    def run(self, user_input: Optional[InputSchema] = None) -> OutputSchema:
        """
        Runs the chat agent with the given user input synchronously.

        Args:
            user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history.

        Returns:
            OutputSchema: The response from the chat agent.
        """
        assert not isinstance(
            self.client, instructor.client.AsyncInstructor
        ), "The run method is not supported for async clients. Use run_async instead."
        if user_input:
            self.history.initialize_turn()
            self.current_user_input = user_input
            self.history.add_message("user", user_input)
        self._prepare_messages()
        response = self.client.chat.completions.create(
            messages=self.messages,
            model=self.model,
            response_model=self.output_schema,
            **self.model_api_parameters,
        )
        self.history.add_message("assistant", response)
        return response

    def run_stream(self, user_input: Optional[InputSchema] = None) -> Generator[OutputSchema, None, OutputSchema]:
        """
        Runs the chat agent with the given user input, supporting streaming output.

        Args:
            user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history.

        Yields:
            OutputSchema: Partial responses from the chat agent.

        Returns:
            OutputSchema: The final response from the chat agent.
        """
        assert not isinstance(
            self.client, instructor.client.AsyncInstructor
        ), "The run_stream method is not supported for async clients. Use run_async_stream instead."
        if user_input:
            self.history.initialize_turn()
            self.current_user_input = user_input
            self.history.add_message("user", user_input)
        self._prepare_messages()
        response_stream = self.client.chat.completions.create_partial(
            model=self.model,
            messages=self.messages,
            response_model=self.output_schema,
            **self.model_api_parameters,
            stream=True,
        )
        for partial_response in response_stream:
            yield partial_response
        # Only the final accumulated response is persisted to history.
        full_response_content = self.output_schema(**partial_response.model_dump())
        self.history.add_message("assistant", full_response_content)
        return full_response_content

    async def run_async(self, user_input: Optional[InputSchema] = None) -> OutputSchema:
        """
        Runs the chat agent asynchronously with the given user input.

        Args:
            user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history.

        Returns:
            OutputSchema: The response from the chat agent.

        Raises:
            NotAsyncIterableError: If used as an async generator (in an async for loop).
                Use the run_async_stream() method instead for streaming responses.
        """
        assert isinstance(self.client, instructor.client.AsyncInstructor), "The run_async method is for async clients."
        if user_input:
            self.history.initialize_turn()
            self.current_user_input = user_input
            self.history.add_message("user", user_input)
        self._prepare_messages()
        response = await self.client.chat.completions.create(
            model=self.model, messages=self.messages, response_model=self.output_schema, **self.model_api_parameters
        )
        self.history.add_message("assistant", response)
        return response

    async def run_async_stream(self, user_input: Optional[InputSchema] = None) -> AsyncGenerator[OutputSchema, None]:
        """
        Runs the chat agent asynchronously with the given user input, supporting streaming output.

        Args:
            user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history.

        Yields:
            OutputSchema: Partial responses from the chat agent.
        """
        assert isinstance(
            self.client, instructor.client.AsyncInstructor
        ), "The run_async_stream method is for async clients."
        if user_input:
            self.history.initialize_turn()
            self.current_user_input = user_input
            self.history.add_message("user", user_input)
        self._prepare_messages()
        response_stream = self.client.chat.completions.create_partial(
            model=self.model,
            messages=self.messages,
            response_model=self.output_schema,
            **self.model_api_parameters,
            stream=True,
        )
        last_response = None
        async for partial_response in response_stream:
            last_response = partial_response
            yield partial_response
        # Only the final accumulated response is persisted to history.
        if last_response:
            full_response_content = self.output_schema(**last_response.model_dump())
            self.history.add_message("assistant", full_response_content)

    def get_context_provider(self, provider_name: str) -> BaseDynamicContextProvider:
        """
        Retrieves a context provider by name.

        Args:
            provider_name (str): The name of the context provider.

        Returns:
            BaseDynamicContextProvider: The context provider if found.

        Raises:
            KeyError: If the context provider is not found.
        """
        if provider_name not in self.system_prompt_generator.context_providers:
            raise KeyError(f"Context provider '{provider_name}' not found.")
        return self.system_prompt_generator.context_providers[provider_name]

    def register_context_provider(self, provider_name: str, provider: BaseDynamicContextProvider):
        """
        Registers a new context provider.

        Args:
            provider_name (str): The name of the context provider.
            provider (BaseDynamicContextProvider): The context provider instance.
        """
        self.system_prompt_generator.context_providers[provider_name] = provider

    def unregister_context_provider(self, provider_name: str):
        """
        Unregisters an existing context provider.

        Args:
            provider_name (str): The name of the context provider to remove.

        Raises:
            KeyError: If the context provider is not found.
        """
        if provider_name in self.system_prompt_generator.context_providers:
            del self.system_prompt_generator.context_providers[provider_name]
        else:
            raise KeyError(f"Context provider '{provider_name}' not found.")
if __name__ == "__main__":
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.syntax import Syntax
from rich import box
from openai import OpenAI, AsyncOpenAI
import instructor
import asyncio
from rich.live import Live
import json

    def _create_schema_table(title: str, schema: Type[BaseModel]) -> Table:
        """Create a table displaying schema information.

        Args:
            title (str): Title of the table
            schema (Type[BaseModel]): Schema to display

        Returns:
            Table: Rich table containing schema information
        """
        schema_table = Table(title=title, box=box.ROUNDED)
        schema_table.add_column("Field", style="cyan")
        schema_table.add_column("Type", style="magenta")
        schema_table.add_column("Description", style="green")
        for field_name, field in schema.model_fields.items():
            schema_table.add_row(field_name, str(field.annotation), field.description or "")
        return schema_table

    def _create_config_table(agent: AtomicAgent) -> Table:
        """Create a table displaying agent configuration.

        Args:
            agent (AtomicAgent): Agent instance

        Returns:
            Table: Rich table containing configuration information
        """
        info_table = Table(title="Agent Configuration", box=box.ROUNDED)
        info_table.add_column("Property", style="cyan")
        info_table.add_column("Value", style="yellow")
        info_table.add_row("Model", agent.model)
        info_table.add_row("History", str(type(agent.history).__name__))
        info_table.add_row("System Prompt Generator", str(type(agent.system_prompt_generator).__name__))
        return info_table

    def display_agent_info(agent: AtomicAgent):
        """Display information about the agent's configuration and schemas."""
        console = Console()
        console.print(
            Panel.fit(
                "[bold blue]Agent Information[/bold blue]",
                border_style="blue",
                padding=(1, 1),
            )
        )

        # Display input schema
        input_schema_table = _create_schema_table("Input Schema", agent.input_schema)
        console.print(input_schema_table)

        # Display output schema
        output_schema_table = _create_schema_table("Output Schema", agent.output_schema)
        console.print(output_schema_table)

        # Display configuration
        info_table = _create_config_table(agent)
        console.print(info_table)

        # Display system prompt
        system_prompt = agent.system_prompt_generator.generate_prompt()
        console.print(
            Panel(
                Syntax(system_prompt, "markdown", theme="monokai", line_numbers=True),
                title="Sample System Prompt",
                border_style="green",
                expand=False,
            )
        )

    async def chat_loop(streaming: bool = False):
        """Interactive chat loop with the AI agent.

        Args:
            streaming (bool): Whether to use streaming mode for responses
        """
        # Streaming uses run_async_stream and therefore needs an async client;
        # the blocking run() path uses a sync client.
        if streaming:
            client = instructor.from_openai(AsyncOpenAI())
        else:
            client = instructor.from_openai(OpenAI())
        config = AgentConfig(client=client, model="gpt-4o-mini")
        agent = AtomicAgent[BasicChatInputSchema, BasicChatOutputSchema](config)

        # Display agent information before starting the chat
        display_agent_info(agent)

        console = Console()
        console.print(
            Panel.fit(
                "[bold blue]Interactive Chat Mode[/bold blue]\n"
                f"[cyan]Streaming: {streaming}[/cyan]\n"
                "Type 'exit' to quit",
                border_style="blue",
                padding=(1, 1),
            )
        )

        while True:
            user_message = console.input("\n[bold green]You:[/bold green] ")
            if user_message.lower() == "exit":
                console.print("[yellow]Goodbye![/yellow]")
                break

            user_input = agent.input_schema(chat_message=user_message)
            console.print("[bold blue]Assistant:[/bold blue]")

            if streaming:
                with Live(console=console, refresh_per_second=4) as live:
                    # Use run_async_stream instead of run_async for streaming responses
                    async for partial_response in agent.run_async_stream(user_input):
                        response_json = partial_response.model_dump()
                        json_str = json.dumps(response_json, indent=2)
                        live.update(json_str)
            else:
                response = agent.run(user_input)
                response_json = response.model_dump()
                json_str = json.dumps(response_json, indent=2)
                console.print(json_str)

    console = Console()
    console.print("\n[bold]Starting chat loop...[/bold]")
    asyncio.run(chat_loop(streaming=True))