from typing import Optional, Type
import warnings

import instructor
from pydantic import BaseModel, Field

from atomic_agents.lib.base.base_io_schema import BaseIOSchema
from atomic_agents.lib.components.agent_memory import AgentMemory
from atomic_agents.lib.components.system_prompt_generator import (
    SystemPromptContextProviderBase,
    SystemPromptGenerator,
)


class BaseAgentInputSchema(BaseIOSchema):
    """This schema represents the input from the user to the AI agent."""

    chat_message: str = Field(
        ...,
        description="The chat message sent by the user to the assistant.",
    )
class BaseAgentOutputSchema(BaseIOSchema):
    """This schema represents the response generated by the chat agent."""

    chat_message: str = Field(
        ...,
        description=(
            "The chat message exchanged between the user and the chat agent. "
            "This contains the markdown-enabled response generated by the chat agent."
        ),
    )
class BaseAgentConfig(BaseModel):
    client: instructor.client.Instructor = Field(..., description="Client for interacting with the language model.")
    model: str = Field("gpt-4o-mini", description="The model to use for generating responses.")
    memory: Optional[AgentMemory] = Field(None, description="Memory component for storing chat history.")
    system_prompt_generator: Optional[SystemPromptGenerator] = Field(
        None, description="Component for generating system prompts."
    )
    input_schema: Optional[Type[BaseModel]] = Field(None, description="The schema for the input data.")
    output_schema: Optional[Type[BaseModel]] = Field(None, description="The schema for the output data.")
    model_config = {"arbitrary_types_allowed": True}
    temperature: Optional[float] = Field(
        0,
        description="Temperature for response generation, typically ranging from 0 to 1.",
    )
    max_tokens: Optional[int] = Field(
        None,
        description="Maximum number of tokens allowed in the response generation.",
    )
    model_api_parameters: Optional[dict] = Field(None, description="Additional parameters passed to the API provider.")
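# Construction sketch (illustrative, not part of the module): wrapping an OpenAI
# client with instructor and building a config, as the demo at the bottom of this
# file does. Assumes the `openai` package is installed and OPENAI_API_KEY is set.
#
#   import instructor
#   from openai import OpenAI
#
#   client = instructor.from_openai(OpenAI())
#   config = BaseAgentConfig(
#       client=client,
#       model="gpt-4o-mini",
#       model_api_parameters={"temperature": 0.2, "max_tokens": 512},
#   )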
class BaseAgent:
    """
    Base class for chat agents.

    This class provides the core functionality for handling chat interactions, including managing memory,
    generating system prompts, and obtaining responses from a language model.

    Attributes:
        input_schema (Type[BaseIOSchema]): Schema for the input data.
        output_schema (Type[BaseIOSchema]): Schema for the output data.
        client: Client for interacting with the language model.
        model (str): The model to use for generating responses.
        memory (AgentMemory): Memory component for storing chat history.
        system_prompt_generator (SystemPromptGenerator): Component for generating system prompts.
        initial_memory (AgentMemory): Initial state of the memory.
        temperature (float): Temperature for response generation, typically ranging from 0 to 1.
            For models such as OpenAI o3-mini that do not support temperature, you must explicitly pass None.
            DEPRECATED: Include 'temperature' in model_api_parameters instead.
        max_tokens (int): Maximum number of tokens allowed in the response.
            DEPRECATED: Include 'max_tokens' in model_api_parameters instead.
        model_api_parameters (dict): Additional parameters passed to the API provider.
    """

    input_schema = BaseAgentInputSchema
    output_schema = BaseAgentOutputSchema
    def __init__(self, config: BaseAgentConfig):
        """
        Initializes the BaseAgent.

        Args:
            config (BaseAgentConfig): Configuration for the chat agent.
        """
        self.input_schema = config.input_schema or self.input_schema
        self.output_schema = config.output_schema or self.output_schema
        self.client = config.client
        self.model = config.model
        self.memory = config.memory or AgentMemory()
        self.system_prompt_generator = config.system_prompt_generator or SystemPromptGenerator()
        self.initial_memory = self.memory.copy()
        self.current_user_input = None
        self.model_api_parameters = config.model_api_parameters or {}

        if config.temperature is not None:
            warnings.warn(
                "'temperature' is deprecated and will soon be removed. Please use 'model_api_parameters' instead.",
                DeprecationWarning,
            )
            if "temperature" not in self.model_api_parameters:
                self.model_api_parameters["temperature"] = config.temperature

        if config.max_tokens is not None:
            warnings.warn(
                "'max_tokens' is deprecated and will soon be removed. Please use 'model_api_parameters' instead.",
                DeprecationWarning,
            )
            self.model_api_parameters["max_tokens"] = config.max_tokens
    def reset_memory(self):
        """
        Resets the memory to its initial state.
        """
        self.memory = self.initial_memory.copy()
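    # Usage sketch (illustrative): reverting the agent to its initial memory
    # between independent conversations, so earlier turns do not leak context.
    #
    #   agent.run(BaseAgentInputSchema(chat_message="Hello!"))
    #   agent.reset_memory()  # the next run starts from the initial history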
    def get_response(self, response_model=None) -> BaseModel:
        """
        Obtains a response from the language model synchronously.

        Args:
            response_model (Type[BaseModel], optional): The schema for the response data. If not set, self.output_schema is used.

        Returns:
            BaseModel: The response from the language model.
        """
        if response_model is None:
            response_model = self.output_schema

        messages = [
            {
                "role": "system",
                "content": self.system_prompt_generator.generate_prompt(),
            }
        ] + self.memory.get_history()
        response = self.client.chat.completions.create(
            messages=messages,
            model=self.model,
            response_model=response_model,
            **self.model_api_parameters,
        )
        return response
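    # Sketch: requesting a custom structured response for a single call,
    # bypassing the agent's default output schema. `SentimentSchema` is a
    # hypothetical example schema, not part of the library.
    #
    #   class SentimentSchema(BaseIOSchema):
    #       """Sentiment classification of the conversation so far."""
    #       sentiment: str = Field(..., description="positive, negative, or neutral")
    #
    #   result = agent.get_response(response_model=SentimentSchema)
    #   print(result.sentiment)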
    def run(self, user_input: Optional[BaseIOSchema] = None) -> BaseIOSchema:
        """
        Runs the chat agent with the given user input synchronously.

        Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

        Returns:
            BaseIOSchema: The response from the chat agent.
        """
        if user_input:
            self.memory.initialize_turn()
            self.current_user_input = user_input
            self.memory.add_message("user", user_input)

        response = self.get_response(response_model=self.output_schema)
        self.memory.add_message("assistant", response)
        return response
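    # Typical synchronous turn (illustrative; `agent` built as in the config
    # sketch above):
    #
    #   response = agent.run(BaseAgentInputSchema(chat_message="What is 2 + 2?"))
    #   print(response.chat_message)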
    async def run_async(self, user_input: Optional[BaseIOSchema] = None):
        """
        Runs the chat agent with the given user input, supporting streaming output asynchronously.

        Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

        Yields:
            BaseModel: Partial responses from the chat agent.
        """
        if user_input:
            self.memory.initialize_turn()
            self.current_user_input = user_input
            self.memory.add_message("user", user_input)

        messages = [
            {
                "role": "system",
                "content": self.system_prompt_generator.generate_prompt(),
            }
        ] + self.memory.get_history()

        response_stream = self.client.chat.completions.create_partial(
            model=self.model,
            messages=messages,
            response_model=self.output_schema,
            **self.model_api_parameters,
            stream=True,
        )

        async for partial_response in response_stream:
            yield partial_response

        full_response_content = self.output_schema(**partial_response.model_dump())
        self.memory.add_message("assistant", full_response_content)
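    # Streaming sketch (illustrative; assumes the config was built with an async
    # client, e.g. instructor.from_openai(AsyncOpenAI()), as in the demo below):
    #
    #   import asyncio
    #
    #   async def main():
    #       user_input = BaseAgentInputSchema(chat_message="Tell me a joke.")
    #       async for partial in agent.run_async(user_input):
    #           print(partial.chat_message)
    #
    #   asyncio.run(main())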
    async def stream_response_async(self, user_input: Optional[BaseIOSchema] = None):
        """
        Deprecated method for streaming responses asynchronously. Use run_async instead.

        Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

        Yields:
            BaseModel: Partial responses from the chat agent.
        """
        warnings.warn(
            "stream_response_async is deprecated and will be removed in version 1.1. "
            "Use run_async instead which can be used in the exact same way.",
            DeprecationWarning,
            stacklevel=2,
        )
        async for response in self.run_async(user_input):
            yield response
    def get_context_provider(self, provider_name: str) -> SystemPromptContextProviderBase:
        """
        Retrieves a context provider by name.

        Args:
            provider_name (str): The name of the context provider.

        Returns:
            SystemPromptContextProviderBase: The context provider if found.

        Raises:
            KeyError: If the context provider is not found.
        """
        if provider_name not in self.system_prompt_generator.context_providers:
            raise KeyError(f"Context provider '{provider_name}' not found.")
        return self.system_prompt_generator.context_providers[provider_name]
    def register_context_provider(self, provider_name: str, provider: SystemPromptContextProviderBase):
        """
        Registers a new context provider.

        Args:
            provider_name (str): The name of the context provider.
            provider (SystemPromptContextProviderBase): The context provider instance.
        """
        self.system_prompt_generator.context_providers[provider_name] = provider
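    # Sketch of registering and retrieving a provider. `CurrentDateProvider` is a
    # hypothetical subclass, not shipped with the library; this assumes
    # SystemPromptContextProviderBase accepts a `title` argument and requires
    # get_info() to be implemented.
    #
    #   class CurrentDateProvider(SystemPromptContextProviderBase):
    #       def get_info(self) -> str:
    #           return "The current date is 2025-01-01."
    #
    #   agent.register_context_provider("current_date", CurrentDateProvider(title="Current date"))
    #   provider = agent.get_context_provider("current_date")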
    def unregister_context_provider(self, provider_name: str):
        """
        Unregisters an existing context provider.

        Args:
            provider_name (str): The name of the context provider to remove.

        Raises:
            KeyError: If the context provider is not found.
        """
        if provider_name in self.system_prompt_generator.context_providers:
            del self.system_prompt_generator.context_providers[provider_name]
        else:
            raise KeyError(f"Context provider '{provider_name}' not found.")
if __name__ == "__main__":
    from rich.console import Console
    from rich.panel import Panel
    from rich.table import Table
    from rich.syntax import Syntax
    from rich import box
    from openai import OpenAI, AsyncOpenAI
    import instructor
    import asyncio
    from rich.live import Live
    import json

    def _create_schema_table(title: str, schema: Type[BaseModel]) -> Table:
        """Create a table displaying schema information.

        Args:
            title (str): Title of the table
            schema (Type[BaseModel]): Schema to display

        Returns:
            Table: Rich table containing schema information
        """
        schema_table = Table(title=title, box=box.ROUNDED)
        schema_table.add_column("Field", style="cyan")
        schema_table.add_column("Type", style="magenta")
        schema_table.add_column("Description", style="green")
        for field_name, field in schema.model_fields.items():
            schema_table.add_row(field_name, str(field.annotation), field.description or "")
        return schema_table

    def _create_config_table(agent: BaseAgent) -> Table:
        """Create a table displaying agent configuration.

        Args:
            agent (BaseAgent): Agent instance

        Returns:
            Table: Rich table containing configuration information
        """
        info_table = Table(title="Agent Configuration", box=box.ROUNDED)
        info_table.add_column("Property", style="cyan")
        info_table.add_column("Value", style="yellow")
        info_table.add_row("Model", agent.model)
        info_table.add_row("Memory", str(type(agent.memory).__name__))
        info_table.add_row("System Prompt Generator", str(type(agent.system_prompt_generator).__name__))
        return info_table

    def display_agent_info(agent: BaseAgent):
        """Display information about the agent's configuration and schemas."""
        console = Console()
        console.print(
            Panel.fit(
                "[bold blue]Agent Information[/bold blue]",
                border_style="blue",
                padding=(1, 1),
            )
        )

        # Display input schema
        input_schema_table = _create_schema_table("Input Schema", agent.input_schema)
        console.print(input_schema_table)

        # Display output schema
        output_schema_table = _create_schema_table("Output Schema", agent.output_schema)
        console.print(output_schema_table)

        # Display configuration
        info_table = _create_config_table(agent)
        console.print(info_table)

        # Display system prompt
        system_prompt = agent.system_prompt_generator.generate_prompt()
        console.print(
            Panel(
                Syntax(system_prompt, "markdown", theme="monokai", line_numbers=True),
                title="Sample System Prompt",
                border_style="green",
                expand=False,
            )
        )

    async def chat_loop(streaming: bool = False):
        """Interactive chat loop with the AI agent.

        Args:
            streaming (bool): Whether to use streaming mode for responses
        """
        if streaming:
            client = instructor.from_openai(AsyncOpenAI())
            config = BaseAgentConfig(client=client, model="gpt-4o-mini")
            agent = BaseAgent(config)
        else:
            client = instructor.from_openai(OpenAI())
            config = BaseAgentConfig(client=client, model="gpt-4o-mini")
            agent = BaseAgent(config)

        # Display agent information before starting the chat
        display_agent_info(agent)

        console = Console()
        console.print(
            Panel.fit(
                "[bold blue]Interactive Chat Mode[/bold blue]\n"
                f"[cyan]Streaming: {streaming}[/cyan]\n"
                "Type 'exit' to quit",
                border_style="blue",
                padding=(1, 1),
            )
        )

        while True:
            user_message = console.input("\n[bold green]You:[/bold green] ")
            if user_message.lower() == "exit":
                console.print("[yellow]Goodbye![/yellow]")
                break

            user_input = agent.input_schema(chat_message=user_message)
            console.print("[bold blue]Assistant:[/bold blue]")

            if streaming:
                with Live(console=console, refresh_per_second=4) as live:
                    async for partial_response in agent.run_async(user_input):
                        response_json = partial_response.model_dump()
                        json_str = json.dumps(response_json, indent=2)
                        live.update(json_str)
            else:
                response = agent.run(user_input)
                response_json = response.model_dump()
                json_str = json.dumps(response_json, indent=2)
                console.print(json_str)

    console = Console()
    console.print("\n[bold]Starting chat loop...[/bold]")
    asyncio.run(chat_loop(streaming=True))