Module atomic_agents.agents.base_agent
Functions
async def model_from_chunks_async_patched(cls, json_chunks, **kwargs)
def model_from_chunks_patched(cls, json_chunks, **kwargs)
Classes
class BaseAgent (config: BaseAgentConfig)
-
Base class for chat agents.
This class provides the core functionality for handling chat interactions, including managing memory, generating system prompts, and obtaining responses from a language model.
Attributes
input_schema
:Type[BaseIOSchema]
- Schema for the input data.
output_schema
:Type[BaseIOSchema]
- Schema for the output data.
client
- Client for interacting with the language model.
model
:str
- The model to use for generating responses.
memory
:AgentMemory
- Memory component for storing chat history.
system_prompt_generator
:SystemPromptGenerator
- Component for generating system prompts.
initial_memory
:AgentMemory
- Initial state of the memory.
max_tokens
:int
- Maximum number of tokens allowed in the response
Initializes the BaseAgent.
Args
config
:BaseAgentConfig
- Configuration for the chat agent.
Expand source code
class BaseAgent:
    """
    Base class for chat agents.

    This class provides the core functionality for handling chat interactions, including
    managing memory, generating system prompts, and obtaining responses from a language model.

    Attributes:
        input_schema (Type[BaseIOSchema]): Schema for the input data.
        output_schema (Type[BaseIOSchema]): Schema for the output data.
        client: Client for interacting with the language model.
        model (str): The model to use for generating responses.
        memory (AgentMemory): Memory component for storing chat history.
        system_prompt_generator (SystemPromptGenerator): Component for generating system prompts.
        initial_memory (AgentMemory): Initial state of the memory.
        max_tokens (int): Maximum number of tokens allowed in the response.
    """

    input_schema = BaseAgentInputSchema
    output_schema = BaseAgentOutputSchema

    def __init__(self, config: BaseAgentConfig):
        """
        Initializes the BaseAgent.

        Args:
            config (BaseAgentConfig): Configuration for the chat agent.
        """
        # Config may override the class-level default schemas.
        self.input_schema = config.input_schema or self.input_schema
        self.output_schema = config.output_schema or self.output_schema
        self.client = config.client
        self.model = config.model
        self.memory = config.memory or AgentMemory()
        self.system_prompt_generator = config.system_prompt_generator or SystemPromptGenerator()
        # Snapshot the starting memory so reset_memory() can restore it later.
        self.initial_memory = self.memory.copy()
        self.current_user_input = None
        self.temperature = config.temperature
        self.max_tokens = config.max_tokens

    def reset_memory(self):
        """Resets the memory to its initial state."""
        self.memory = self.initial_memory.copy()

    def get_response(self, response_model=None) -> BaseModel:
        """
        Obtains a response from the language model synchronously.

        Args:
            response_model (Type[BaseModel], optional): The schema for the response data.
                If not set, self.output_schema is used.

        Returns:
            BaseModel: The response from the language model.
        """
        if response_model is None:
            response_model = self.output_schema

        # The system prompt always leads; the stored chat history follows.
        messages = [
            {
                "role": "system",
                "content": self.system_prompt_generator.generate_prompt(),
            }
        ] + self.memory.get_history()

        response = self.client.chat.completions.create(
            messages=messages,
            model=self.model,
            response_model=response_model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        return response

    def run(self, user_input: Optional[BaseIOSchema] = None) -> BaseIOSchema:
        """
        Runs the chat agent with the given user input synchronously.

        Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

        Returns:
            BaseIOSchema: The response from the chat agent.
        """
        if user_input:
            self.memory.initialize_turn()
            self.current_user_input = user_input
            self.memory.add_message("user", user_input)

        response = self.get_response(response_model=self.output_schema)
        self.memory.add_message("assistant", response)

        return response

    async def run_async(self, user_input: Optional[BaseIOSchema] = None):
        """
        Runs the chat agent with the given user input, supporting streaming output asynchronously.

        Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

        Yields:
            BaseModel: Partial responses from the chat agent.
        """
        if user_input:
            self.memory.initialize_turn()
            self.current_user_input = user_input
            self.memory.add_message("user", user_input)

        messages = [
            {
                "role": "system",
                "content": self.system_prompt_generator.generate_prompt(),
            }
        ] + self.memory.get_history()

        response_stream = self.client.chat.completions.create_partial(
            model=self.model,
            messages=messages,
            response_model=self.output_schema,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stream=True,
        )

        # Guard against an empty stream: previously, an empty stream left
        # partial_response unbound and raised UnboundLocalError below.
        partial_response = None
        async for partial_response in response_stream:
            yield partial_response

        if partial_response is not None:
            # Persist the final accumulated response to memory once streaming ends.
            full_response_content = self.output_schema(**partial_response.model_dump())
            self.memory.add_message("assistant", full_response_content)

    async def stream_response_async(self, user_input: Optional[BaseIOSchema] = None):
        """
        Deprecated method for streaming responses asynchronously. Use run_async instead.

        Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

        Yields:
            BaseModel: Partial responses from the chat agent.
        """
        warnings.warn(
            "stream_response_async is deprecated and will be removed in version 1.1. Use run_async instead which can be used in the exact same way.",
            DeprecationWarning,
            stacklevel=2,
        )
        async for response in self.run_async(user_input):
            yield response

    def get_context_provider(self, provider_name: str) -> SystemPromptContextProviderBase:
        """
        Retrieves a context provider by name.

        Args:
            provider_name (str): The name of the context provider.

        Returns:
            SystemPromptContextProviderBase: The context provider if found.

        Raises:
            KeyError: If the context provider is not found.
        """
        if provider_name not in self.system_prompt_generator.context_providers:
            raise KeyError(f"Context provider '{provider_name}' not found.")
        return self.system_prompt_generator.context_providers[provider_name]

    def register_context_provider(self, provider_name: str, provider: SystemPromptContextProviderBase):
        """
        Registers a new context provider.

        Args:
            provider_name (str): The name of the context provider.
            provider (SystemPromptContextProviderBase): The context provider instance.
        """
        self.system_prompt_generator.context_providers[provider_name] = provider

    def unregister_context_provider(self, provider_name: str):
        """
        Unregisters an existing context provider.

        Args:
            provider_name (str): The name of the context provider to remove.

        Raises:
            KeyError: If the context provider is not found.
        """
        if provider_name in self.system_prompt_generator.context_providers:
            del self.system_prompt_generator.context_providers[provider_name]
        else:
            raise KeyError(f"Context provider '{provider_name}' not found.")
Class variables
var input_schema
-
This schema represents the input from the user to the AI agent.
var output_schema
-
This schema represents the response generated by the chat agent.
Methods
def get_context_provider(self, provider_name: str) ‑> Type[SystemPromptContextProviderBase]
-
Retrieves a context provider by name.
Args
provider_name
:str
- The name of the context provider.
Returns
SystemPromptContextProviderBase
- The context provider if found.
Raises
KeyError
- If the context provider is not found.
def get_response(self, response_model=None) ‑> Type[pydantic.main.BaseModel]
-
Obtains a response from the language model synchronously.
Args
response_model (Type[BaseModel], optional): The schema for the response data. If not set, self.output_schema is used.
Returns
Type[BaseModel]
- The response from the language model.
def register_context_provider(self, provider_name: str, provider: SystemPromptContextProviderBase)
-
Registers a new context provider.
Args
provider_name
:str
- The name of the context provider.
provider
:SystemPromptContextProviderBase
- The context provider instance.
def reset_memory(self)
-
Resets the memory to its initial state.
def run(self, user_input: Optional[BaseIOSchema] = None) ‑> BaseIOSchema
-
Runs the chat agent with the given user input synchronously.
Args
user_input
:Optional[BaseIOSchema]
- The input from the user. If not provided, skips adding to memory.
Returns
BaseIOSchema
- The response from the chat agent.
async def run_async(self, user_input: Optional[BaseIOSchema] = None)
-
Runs the chat agent with the given user input, supporting streaming output asynchronously.
Args
user_input
:Optional[BaseIOSchema]
- The input from the user. If not provided, skips adding to memory.
Yields
BaseModel
- Partial responses from the chat agent.
async def stream_response_async(self, user_input: Optional[Type[BaseIOSchema]] = None)
-
Deprecated method for streaming responses asynchronously. Use run_async instead.
Args
user_input
:Optional[Type[BaseIOSchema]]
- The input from the user. If not provided, skips adding to memory.
Yields
BaseModel
- Partial responses from the chat agent.
def unregister_context_provider(self, provider_name: str)
-
Unregisters an existing context provider.
Args
provider_name
:str
- The name of the context provider to remove.
class BaseAgentConfig (**data: Any)
-
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes
__class_vars__
- The names of the class variables defined on the model.
__private_attributes__
- Metadata about the private attributes of the model.
__signature__
- The synthesized
__init__
[Signature
][inspect.Signature] of the model. __pydantic_complete__
- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
- The core schema of the model.
__pydantic_custom_init__
- Whether the model has a custom
__init__
function. __pydantic_decorators__
- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__
and `Model.__root_validators__`
from Pydantic V1. __pydantic_generic_metadata__
- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
- The name of the post-init method for the model, if defined.
__pydantic_root_model__
- Whether the model is a [
RootModel
][pydantic.root_model.RootModel]. __pydantic_serializer__
- The
pydantic-core
SchemaSerializer
used to dump instances of the model. __pydantic_validator__
- The
pydantic-core
SchemaValidator
used to validate instances of the model. __pydantic_fields__
- A dictionary of field names and their corresponding [
FieldInfo
][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__
- A dictionary of computed field names and their corresponding [
ComputedFieldInfo
][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__
- A dictionary containing extra values, if [
extra
][pydantic.config.ConfigDict.extra] is set to 'allow'
. __pydantic_fields_set__
- The names of fields explicitly set during instantiation.
__pydantic_private__
- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError
][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model. `self`
is explicitly positional-only to allow `self`
as a field name.
Expand source code
class BaseAgentConfig(BaseModel):
    """Configuration for a BaseAgent: the LLM client/model plus optional memory, prompt generator, schemas, and sampling settings."""

    # Allow non-pydantic types (e.g. the instructor client, AgentMemory) as field values.
    model_config = {"arbitrary_types_allowed": True}

    client: instructor.client.Instructor = Field(..., description="Client for interacting with the language model.")
    model: str = Field("gpt-4o-mini", description="The model to use for generating responses.")
    memory: Optional[AgentMemory] = Field(None, description="Memory component for storing chat history.")
    system_prompt_generator: Optional[SystemPromptGenerator] = Field(
        None, description="Component for generating system prompts."
    )
    input_schema: Optional[Type[BaseModel]] = Field(None, description="The schema for the input data.")
    output_schema: Optional[Type[BaseModel]] = Field(None, description="The schema for the output data.")
    temperature: Optional[float] = Field(
        0,
        description="Temperature for response generation, typically ranging from 0 to 1.",
    )
    max_tokens: Optional[int] = Field(
        None,
        # Fixed typo: "number of token" -> "number of tokens".
        description="Maximum number of tokens allowed in the response generation.",
    )
Ancestors
- pydantic.main.BaseModel
Class variables
var client : instructor.client.Instructor
-
Client for interacting with the language model.
var input_schema : Optional[Type[pydantic.main.BaseModel]]
-
The schema for the input data.
var max_tokens : Optional[int]
-
Maximum number of tokens allowed in the response generation.
var memory : Optional[AgentMemory]
-
Memory component for storing chat history.
var model : str
-
The model to use for generating responses.
var model_config
-
Pydantic model configuration; arbitrary (non-pydantic) types are allowed as field values.
var output_schema : Optional[Type[pydantic.main.BaseModel]]
-
The schema for the output data.
var system_prompt_generator : Optional[SystemPromptGenerator]
-
Component for generating system prompts.
var temperature : Optional[float]
-
Temperature for response generation, typically ranging from 0 to 1.
class BaseAgentInputSchema (**data: Any)
-
This schema represents the input from the user to the AI agent.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError
][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model. `self`
is explicitly positional-only to allow `self`
as a field name.
Expand source code
class BaseAgentInputSchema(BaseIOSchema):
    """This schema represents the input from the user to the AI agent."""

    # The user's message for this turn; description string kept identical to the original.
    chat_message: str = Field(..., description="The chat message sent by the user to the assistant.")
Ancestors
- BaseIOSchema
- pydantic.main.BaseModel
Class variables
var chat_message : str
-
The chat message sent by the user to the assistant.
Inherited members
class BaseAgentOutputSchema (**data: Any)
-
This schema represents the response generated by the chat agent.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError
][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model. `self`
is explicitly positional-only to allow `self`
as a field name.
Expand source code
class BaseAgentOutputSchema(BaseIOSchema):
    """This schema represents the response generated by the chat agent."""

    # Description collapsed from two adjacent string literals into one; the
    # resulting runtime value is byte-identical to the original.
    chat_message: str = Field(
        ...,
        description="The chat message exchanged between the user and the chat agent. This contains the markdown-enabled response generated by the chat agent.",
    )
Ancestors
- BaseIOSchema
- pydantic.main.BaseModel
Class variables
var chat_message : str
-
The chat message exchanged between the user and the chat agent. This contains the markdown-enabled response generated by the chat agent.
Inherited members