Module atomic_agents.agents.base_agent

Functions

async def model_from_chunks_async_patched(cls, json_chunks, **kwargs)
def model_from_chunks_patched(cls, json_chunks, **kwargs)

Classes

class BaseAgent (config: BaseAgentConfig)

Base class for chat agents.

This class provides the core functionality for handling chat interactions, including managing memory, generating system prompts, and obtaining responses from a language model.

Attributes

input_schema : Type[BaseIOSchema]
Schema for the input data.
output_schema : Type[BaseIOSchema]
Schema for the output data.
client
Client for interacting with the language model.
model : str
The model to use for generating responses.
memory : AgentMemory
Memory component for storing chat history.
system_prompt_generator : SystemPromptGenerator
Component for generating system prompts.
initial_memory : AgentMemory
Initial state of the memory.
max_tokens : int
Maximum number of tokens allowed in the response.

Initializes the BaseAgent.

Args

config : BaseAgentConfig
Configuration for the chat agent.
Expand source code
class BaseAgent:
    """
    Base class for chat agents.

    This class provides the core functionality for handling chat interactions, including managing memory,
    generating system prompts, and obtaining responses from a language model.

    Attributes:
        input_schema (Type[BaseIOSchema]): Schema for the input data.
        output_schema (Type[BaseIOSchema]): Schema for the output data.
        client: Client for interacting with the language model.
        model (str): The model to use for generating responses.
        memory (AgentMemory): Memory component for storing chat history.
        system_prompt_generator (SystemPromptGenerator): Component for generating system prompts.
        initial_memory (AgentMemory): Copy of the memory taken at construction; restored by reset_memory().
        current_user_input (Optional[BaseIOSchema]): The most recent user input given to run()/run_async().
        temperature (Optional[float]): Temperature forwarded to the completion call.
        max_tokens (Optional[int]): Maximum number of tokens allowed in the response.
    """

    input_schema = BaseAgentInputSchema
    output_schema = BaseAgentOutputSchema

    def __init__(self, config: BaseAgentConfig):
        """
        Initializes the BaseAgent.

        Any of input_schema, output_schema, memory, or system_prompt_generator left unset (falsy) in
        the config falls back to the class default or a freshly constructed component.

        Args:
            config (BaseAgentConfig): Configuration for the chat agent.
        """
        self.input_schema = config.input_schema or self.input_schema
        self.output_schema = config.output_schema or self.output_schema
        self.client = config.client
        self.model = config.model
        self.memory = config.memory or AgentMemory()
        self.system_prompt_generator = config.system_prompt_generator or SystemPromptGenerator()
        # Snapshot taken after defaults are applied so reset_memory() can restore it.
        self.initial_memory = self.memory.copy()
        self.current_user_input = None
        self.temperature = config.temperature
        self.max_tokens = config.max_tokens

    def reset_memory(self):
        """
        Resets the memory to its initial state.

        A copy of the snapshot is used so later turns do not mutate the snapshot itself.
        """
        self.memory = self.initial_memory.copy()

    def _record_user_input(self, user_input: Optional[BaseIOSchema]) -> None:
        """Start a new memory turn and store ``user_input`` as a user message; no-op when it is None."""
        if user_input is None:
            return
        self.memory.initialize_turn()
        self.current_user_input = user_input
        self.memory.add_message("user", user_input)

    def _build_messages(self) -> list:
        """Return the request messages: the generated system prompt followed by the memory history."""
        system_message = {
            "role": "system",
            "content": self.system_prompt_generator.generate_prompt(),
        }
        return [system_message] + self.memory.get_history()

    def get_response(self, response_model: Optional[Type[BaseModel]] = None) -> BaseModel:
        """
        Obtains a response from the language model synchronously.

        The request consists of the generated system prompt followed by the full memory history.
        Memory is not modified by this method.

        Args:
            response_model (Type[BaseModel], optional):
                The schema for the response data. If not set, self.output_schema is used.

        Returns:
            BaseModel: The response returned by the client (expected to be an instance of response_model).
        """
        if response_model is None:
            response_model = self.output_schema

        messages = self._build_messages()

        return self.client.chat.completions.create(
            messages=messages,
            model=self.model,
            response_model=response_model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

    def run(self, user_input: Optional[BaseIOSchema] = None) -> BaseIOSchema:
        """
        Runs the chat agent with the given user input synchronously.

        Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

        Returns:
            BaseIOSchema: The response from the chat agent, which is also appended to memory as an
            assistant message.
        """
        self._record_user_input(user_input)

        response = self.get_response(response_model=self.output_schema)
        self.memory.add_message("assistant", response)

        return response

    async def run_async(self, user_input: Optional[BaseIOSchema] = None):
        """
        Runs the chat agent with the given user input, supporting streaming output asynchronously.

        After the stream is exhausted, the last partial response is validated into a full
        ``output_schema`` instance and appended to memory as an assistant message.

        Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

        Yields:
            BaseModel: Partial responses from the chat agent.
        """
        self._record_user_input(user_input)

        messages = self._build_messages()

        response_stream = self.client.chat.completions.create_partial(
            model=self.model,
            messages=messages,
            response_model=self.output_schema,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stream=True,
        )

        last_partial = None
        async for partial_response in response_stream:
            last_partial = partial_response
            yield partial_response

        # Previously an empty stream crashed here with UnboundLocalError; with nothing
        # received there is no assistant message to record, so leave memory untouched.
        if last_partial is None:
            return

        full_response_content = self.output_schema(**last_partial.model_dump())
        self.memory.add_message("assistant", full_response_content)

    async def stream_response_async(self, user_input: Optional[BaseIOSchema] = None):
        """
        Deprecated method for streaming responses asynchronously. Use run_async instead.

        Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

        Yields:
            BaseModel: Partial responses from the chat agent.
        """
        warnings.warn(
            "stream_response_async is deprecated and will be removed in version 1.1. Use run_async instead which can be used in the exact same way.",
            DeprecationWarning,
            stacklevel=2,
        )
        async for response in self.run_async(user_input):
            yield response

    def get_context_provider(self, provider_name: str) -> SystemPromptContextProviderBase:
        """
        Retrieves a context provider by name.

        Args:
            provider_name (str): The name of the context provider.

        Returns:
            SystemPromptContextProviderBase: The context provider if found.

        Raises:
            KeyError: If the context provider is not found.
        """
        try:
            return self.system_prompt_generator.context_providers[provider_name]
        except KeyError:
            raise KeyError(f"Context provider '{provider_name}' not found.") from None

    def register_context_provider(self, provider_name: str, provider: SystemPromptContextProviderBase):
        """
        Registers a new context provider.

        An existing provider registered under the same name is replaced.

        Args:
            provider_name (str): The name of the context provider.
            provider (SystemPromptContextProviderBase): The context provider instance.
        """
        self.system_prompt_generator.context_providers[provider_name] = provider

    def unregister_context_provider(self, provider_name: str):
        """
        Unregisters an existing context provider.

        Args:
            provider_name (str): The name of the context provider to remove.

        Raises:
            KeyError: If the context provider is not found.
        """
        try:
            del self.system_prompt_generator.context_providers[provider_name]
        except KeyError:
            raise KeyError(f"Context provider '{provider_name}' not found.") from None

Class variables

var input_schema

This schema represents the input from the user to the AI agent.

var output_schema

This schema represents the response generated by the chat agent.

Methods

def get_context_provider(self, provider_name: str) ‑> Type[SystemPromptContextProviderBase]

Retrieves a context provider by name.

Args

provider_name : str
The name of the context provider.

Returns

SystemPromptContextProviderBase
The context provider if found.

Raises

KeyError
If the context provider is not found.
def get_response(self, response_model=None) ‑> Type[pydantic.main.BaseModel]

Obtains a response from the language model synchronously.

Args

response_model : Type[BaseModel], optional
The schema for the response data. If not set, self.output_schema is used.

Returns

Type[BaseModel]
The response from the language model.
def register_context_provider(self, provider_name: str, provider: SystemPromptContextProviderBase)

Registers a new context provider.

Args

provider_name : str
The name of the context provider.
provider : SystemPromptContextProviderBase
The context provider instance.
def reset_memory(self)

Resets the memory to its initial state.

def run(self, user_input: Optional[BaseIOSchema] = None) ‑> BaseIOSchema

Runs the chat agent with the given user input synchronously.

Args

user_input : Optional[BaseIOSchema]
The input from the user. If not provided, skips adding to memory.

Returns

BaseIOSchema
The response from the chat agent.
async def run_async(self, user_input: Optional[BaseIOSchema] = None)

Runs the chat agent with the given user input, supporting streaming output asynchronously.

Args

user_input : Optional[BaseIOSchema]
The input from the user. If not provided, skips adding to memory.

Yields

BaseModel
Partial responses from the chat agent.
async def stream_response_async(self, user_input: Optional[Type[BaseIOSchema]] = None)

Deprecated method for streaming responses asynchronously. Use run_async instead.

Args

user_input : Optional[Type[BaseIOSchema]]
The input from the user. If not provided, skips adding to memory.

Yields

BaseModel
Partial responses from the chat agent.
def unregister_context_provider(self, provider_name: str)

Unregisters an existing context provider.

Args

provider_name : str
The name of the context provider to remove.
class BaseAgentConfig (**data: Any)

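Configuration model for BaseAgent: the language-model client, model name, memory, system prompt generator, input/output schemas, temperature, and max_tokens. Only client is required.

The remaining text is inherited from pydantic.BaseModel (usage docs: https://docs.pydantic.dev/2.10/concepts/models/): A base class for creating Pydantic models.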

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Expand source code
class BaseAgentConfig(BaseModel):
    """
    Configuration used to construct a BaseAgent.

    Only ``client`` is required. ``memory``, ``system_prompt_generator``, ``input_schema`` and
    ``output_schema`` default to None; BaseAgent substitutes its own defaults for them.
    """

    # Permit arbitrary (non-pydantic) classes, such as the instructor client, as field types.
    model_config = {"arbitrary_types_allowed": True}

    # Required: no default is given.
    client: instructor.client.Instructor = Field(
        default=...,
        description="Client for interacting with the language model.",
    )
    model: str = Field(
        default="gpt-4o-mini",
        description="The model to use for generating responses.",
    )
    memory: Optional[AgentMemory] = Field(
        default=None,
        description="Memory component for storing chat history.",
    )
    system_prompt_generator: Optional[SystemPromptGenerator] = Field(
        default=None,
        description="Component for generating system prompts.",
    )
    input_schema: Optional[Type[BaseModel]] = Field(
        default=None,
        description="The schema for the input data.",
    )
    output_schema: Optional[Type[BaseModel]] = Field(
        default=None,
        description="The schema for the output data.",
    )
    temperature: Optional[float] = Field(
        default=0,
        description="Temperature for response generation, typically ranging from 0 to 1.",
    )
    # None leaves the response length unconstrained by this config.
    max_tokens: Optional[int] = Field(
        default=None,
        description="Maximum number of token allowed in the response generation.",
    )

Ancestors

  • pydantic.main.BaseModel

Class variables

var client : instructor.client.Instructor

Client for interacting with the language model (required).

var input_schema : Optional[Type[pydantic.main.BaseModel]]

The schema for the input data. Defaults to None.

var max_tokens : Optional[int]

Maximum number of tokens allowed in the response generation. Defaults to None.

var memory : Optional[AgentMemory]

Memory component for storing chat history. Defaults to None.

var model : str

The model to use for generating responses. Defaults to "gpt-4o-mini".

var model_config

Pydantic model configuration; sets arbitrary_types_allowed to True.

var output_schema : Optional[Type[pydantic.main.BaseModel]]

The schema for the output data. Defaults to None.

var system_prompt_generator : Optional[SystemPromptGenerator]

Component for generating system prompts. Defaults to None.

var temperature : Optional[float]

Temperature for response generation, typically ranging from 0 to 1. Defaults to 0.

class BaseAgentInputSchema (**data: Any)

This schema represents the input from the user to the AI agent.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Expand source code
class BaseAgentInputSchema(BaseIOSchema):
    """This schema represents the input from the user to the AI agent."""

    # NOTE: pydantic includes the class docstring and Field descriptions in this model's JSON
    # schema, so treat that text as runtime data rather than purely documentation.
    # Required free-text message written by the user.
    chat_message: str = Field(
        ...,
        description="The chat message sent by the user to the assistant.",
    )

Ancestors

  • BaseIOSchema

Class variables

var chat_message : str

The chat message sent by the user to the assistant (required).

Inherited members

class BaseAgentOutputSchema (**data: Any)

This schema represents the response generated by the chat agent.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Expand source code
class BaseAgentOutputSchema(BaseIOSchema):
    """This schema represents the response generated by the chat agent."""

    # NOTE: this is BaseAgent's default output_schema, which run()/run_async() pass to the client
    # as response_model. pydantic includes the class docstring and Field descriptions in the
    # model's JSON schema, so editing that text may change what the model is asked to produce.
    # Required markdown-enabled reply text from the agent.
    chat_message: str = Field(
        ...,
        description=(
            "The chat message exchanged between the user and the chat agent. "
            "This contains the markdown-enabled response generated by the chat agent."
        ),
    )

Ancestors

  • BaseIOSchema

Class variables

var chat_message : str

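The chat message exchanged between the user and the chat agent. This contains the markdown-enabled response generated by the chat agent (required).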

Inherited members