Extending LLM Clients¶
This guide shows how to create custom LLM clients that integrate with the Arshai framework. By implementing the ILLM interface, your custom client will work seamlessly with all framework components.
Note
Use the Google Gemini client (arshai.llms.google_genai.GeminiClient) as your reference implementation. According to the codebase documentation, it serves as the canonical example of best practices for LLM client development.
Framework Architecture¶
All LLM clients in Arshai follow a standardized architecture:
- BaseLLMClient Foundation: All clients extend BaseLLMClient, which provides common functionality for usage tracking, function orchestration, and streaming support.
- ILLM Interface Contract: Clients must implement the ILLM interface with chat() and stream() methods.
- Consistent Patterns: All clients follow identical patterns for function calling, structured output, and background tasks.
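The contract above can be pictured with a toy stand-in. `EchoClient` and `FakeInput` below are hypothetical illustrations, not framework classes: every client exposes an async `chat()` that returns a response/usage dict, and a `stream()` that yields chunk dicts.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict


@dataclass
class FakeInput:
    """Stand-in for ILLMInput, for illustration only."""
    system_prompt: str
    user_message: str


class EchoClient:
    """Toy client satisfying the two-method chat()/stream() contract."""

    async def chat(self, input: FakeInput) -> Dict[str, Any]:
        # Real clients return {"llm_response": ..., "usage": {...}}
        return {"llm_response": f"echo: {input.user_message}",
                "usage": {"total_tokens": 0}}

    async def stream(self, input: FakeInput) -> AsyncGenerator[Dict[str, Any], None]:
        # Real clients yield partial responses as chunks arrive
        for word in input.user_message.split():
            yield {"llm_response": word}


result = asyncio.run(EchoClient().chat(FakeInput("sys", "hello world")))


async def _collect():
    return [c async for c in EchoClient().stream(FakeInput("sys", "hello world"))]

chunks = asyncio.run(_collect())
```

Anything extending BaseLLMClient inherits the orchestration machinery around these two entry points.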
Basic Client Structure¶
Here’s the minimal structure for a custom LLM client:
```python
import os
from typing import Any, AsyncGenerator, Dict

from arshai.core.interfaces.illm import ILLMConfig, ILLMInput
from arshai.llms.base_llm_client import BaseLLMClient
from arshai.llms.utils.function_execution import FunctionCall, FunctionExecutionInput


class CustomLLMClient(BaseLLMClient):
    """
    Custom LLM client implementation.

    This demonstrates the minimal requirements for creating
    a new LLM provider integration.
    """

    def _initialize_client(self) -> Any:
        """Initialize the provider-specific client."""
        # Your provider-specific initialization
        api_key = os.environ.get("CUSTOM_API_KEY")
        if not api_key:
            raise ValueError("CUSTOM_API_KEY environment variable required")
        # Return your provider's client instance
        return YourProviderClient(api_key=api_key)

    def _extract_and_standardize_usage(self, response: Any) -> Dict[str, Any]:
        """Extract usage metadata from provider response."""
        # Convert provider-specific usage to the standard format
        return {
            "input_tokens": response.usage.prompt_tokens,
            "output_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "thinking_tokens": 0,  # If not supported
            "tool_calling_tokens": 0,  # If not available
            "provider": self._provider_name,
            "model": self.config.model,
            "request_id": getattr(response, 'id', None)
        }

    async def _chat_simple(self, input: ILLMInput) -> Dict[str, Any]:
        """Handle simple chat without tools."""
        # Implement basic chat functionality
        response = await self._client.chat(
            messages=[
                {"role": "system", "content": input.system_prompt},
                {"role": "user", "content": input.user_message}
            ],
            model=self.config.model,
            temperature=self.config.temperature
        )
        usage = self._extract_and_standardize_usage(response)
        return {"llm_response": response.content, "usage": usage}

    async def _chat_with_functions(self, input: ILLMInput) -> Dict[str, Any]:
        """Handle chat with function calling."""
        # Implement function calling support
        # See the reference implementation for the complete pattern
        pass

    async def _stream_simple(self, input: ILLMInput) -> AsyncGenerator[Dict[str, Any], None]:
        """Handle simple streaming."""
        # Implement streaming support
        # See the reference implementation for the complete pattern
        pass

    async def _stream_with_functions(self, input: ILLMInput) -> AsyncGenerator[Dict[str, Any], None]:
        """Handle streaming with function calling."""
        # Implement streaming with function support
        # See the reference implementation for the complete pattern
        pass
```
Required Methods¶
Every LLM client must implement these core methods:
1. Client Initialization:
```python
def _initialize_client(self) -> Any:
    """Initialize the provider's client library."""
    # Environment variable validation
    api_key = os.environ.get("YOUR_PROVIDER_API_KEY")
    if not api_key:
        raise ValueError("YOUR_PROVIDER_API_KEY environment variable required")

    # Provider-specific client creation
    try:
        # Use safe HTTP configuration when available
        from arshai.clients.utils.safe_http_client import SafeHttpClientFactory
        return SafeHttpClientFactory.create_your_provider_client(api_key=api_key)
    except ImportError:
        # Fall back to the basic client
        return YourProviderClient(api_key=api_key)
```
2. Usage Standardization:
```python
def _extract_and_standardize_usage(self, response: Any) -> Dict[str, Any]:
    """Convert provider usage to the standard format."""
    # Extract provider-specific usage data
    usage = response.usage  # or however your provider structures this
    return {
        "input_tokens": getattr(usage, 'prompt_tokens', 0),
        "output_tokens": getattr(usage, 'completion_tokens', 0),
        "total_tokens": getattr(usage, 'total_tokens', 0),
        "thinking_tokens": getattr(usage, 'reasoning_tokens', 0),  # If supported
        "tool_calling_tokens": getattr(usage, 'function_tokens', 0),  # If available
        "provider": self._provider_name,  # Set in __init__
        "model": self.config.model,
        "request_id": getattr(response, 'id', None)
    }
```
3. Simple Chat Implementation:
```python
async def _chat_simple(self, input: ILLMInput) -> Dict[str, Any]:
    """Basic chat without tools or structured output."""
    # Prepare messages in your provider's format
    messages = [
        {"role": "system", "content": input.system_prompt},
        {"role": "user", "content": input.user_message}
    ]

    # Build the request, handling structured output if requested
    kwargs = {
        "messages": messages,
        "model": self.config.model,
        "temperature": self.config.temperature,
        "max_tokens": self.config.max_tokens
    }
    if input.structure_type:
        # Add your provider's structured output configuration
        kwargs["response_format"] = self._create_structure_schema(input.structure_type)

    # Make the API call
    response = await self._client.chat.completions.create(**kwargs)

    # Process the response
    usage = self._extract_and_standardize_usage(response)
    if input.structure_type:
        # Parse the structured response
        structured_data = self._parse_structured_response(response, input.structure_type)
        return {"llm_response": structured_data, "usage": usage}
    else:
        # Return the text response
        return {"llm_response": response.choices[0].message.content, "usage": usage}
```
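The helpers `_create_structure_schema` and `_parse_structured_response` are referenced above but not shown. A minimal sketch follows, assuming `structure_type` is a plain dataclass and a response_format shape loosely modeled on OpenAI-style JSON schema; the real framework clients may use Pydantic models and provider-specific schemas instead.

```python
import json
from dataclasses import dataclass, fields
from typing import Any, Dict, Type

# Assumed mapping from Python annotations to JSON schema types
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def create_structure_schema(structure_type: Type[Any]) -> Dict[str, Any]:
    """Build a minimal JSON-schema response_format from dataclass fields."""
    properties = {
        f.name: {"type": _JSON_TYPES.get(f.type, "string")}
        for f in fields(structure_type)
    }
    return {
        "type": "json_schema",
        "json_schema": {
            "name": structure_type.__name__,
            "schema": {
                "type": "object",
                "properties": properties,
                "required": [f.name for f in fields(structure_type)],
            },
        },
    }


def parse_structured_response(raw_text: str, structure_type: Type[Any]) -> Any:
    """Parse the model's JSON text back into an instance of structure_type."""
    data = json.loads(raw_text)
    return structure_type(**data)


@dataclass
class WeatherReport:
    city: str
    temperature_c: float


schema = create_structure_schema(WeatherReport)
report = parse_structured_response('{"city": "Oslo", "temperature_c": 3.5}', WeatherReport)
```

In a real client these would be methods taking `self`, and the parse step would read the JSON text out of the provider's response object first.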
Function Calling Implementation¶
The framework provides orchestration utilities to handle function execution. Your client should focus on provider-specific function calling:
Function Declaration Conversion:
```python
import inspect
from typing import Callable, Dict, List


def _convert_callables_to_provider_format(self, functions: Dict[str, Callable]) -> List[Dict]:
    """Convert Python functions to the provider's function format."""
    provider_functions = []
    for name, func in functions.items():
        try:
            # Use introspection to build the function schema
            sig = inspect.signature(func)
            description = func.__doc__ or f"Execute {name} function"

            # Build the parameter schema
            properties = {}
            required = []
            for param_name, param in sig.parameters.items():
                if param_name == 'self':
                    continue
                param_type = self._python_type_to_json_schema_type(param.annotation)
                properties[param_name] = {
                    "type": param_type,
                    "description": f"{param_name} parameter"
                }
                if param.default is inspect.Parameter.empty:
                    required.append(param_name)

            # Create the provider-specific function definition
            function_def = {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required
                }
            }
            provider_functions.append(function_def)
        except Exception as e:
            self.logger.warning(f"Failed to convert function {name}: {e}")
            continue
    return provider_functions
```
Function Calling Orchestration:
```python
async def _chat_with_functions(self, input: ILLMInput) -> Dict[str, Any]:
    """Handle chat with function calling."""
    messages = self._create_messages(input)

    # Prepare functions for your provider
    provider_functions = []
    all_functions = {}
    if input.regular_functions:
        all_functions.update(input.regular_functions)
    if input.background_tasks:
        all_functions.update(input.background_tasks)
    if all_functions:
        provider_functions = self._convert_callables_to_provider_format(all_functions)

    # Multi-turn conversation loop
    current_turn = 0
    accumulated_usage = None
    while current_turn < input.max_turns:
        # Make the API call with functions
        response = await self._client.chat.completions.create(
            messages=messages,
            model=self.config.model,
            temperature=self.config.temperature,
            tools=provider_functions if provider_functions else None
        )

        # Accumulate usage
        current_usage = self._extract_and_standardize_usage(response)
        accumulated_usage = self._accumulate_usage_safely(current_usage, accumulated_usage)

        # Check for function calls
        function_calls = self._extract_function_calls(response)
        if function_calls:
            # Convert provider calls into framework objects
            function_calls_list = self._prepare_function_calls_for_orchestrator(function_calls, input)

            # Execute functions via the framework orchestrator
            execution_input = FunctionExecutionInput(
                function_calls=function_calls_list,
                available_functions=input.regular_functions or {},
                available_background_tasks=input.background_tasks or {}
            )
            execution_result = await self._execute_functions_with_orchestrator(execution_input)

            # Add results to the conversation
            self._add_function_results_to_messages(execution_result, messages)

            # Continue the loop if regular functions produced results
            if execution_result.get('regular_results'):
                current_turn += 1
                continue

        # Return the final response
        return {
            "llm_response": response.choices[0].message.content,
            "usage": accumulated_usage
        }

    return {
        "llm_response": "Maximum function calling turns reached",
        "usage": accumulated_usage
    }
```
Streaming Implementation¶
Implement streaming with progressive function execution:
```python
async def _stream_with_functions(self, input: ILLMInput) -> AsyncGenerator[Dict[str, Any], None]:
    """Handle streaming with function calling."""
    from arshai.llms.utils.function_execution import StreamingExecutionState

    messages = self._create_messages(input)
    provider_functions = self._prepare_functions(input)

    current_turn = 0
    accumulated_usage = None
    while current_turn < input.max_turns:
        # Create the streaming request
        stream = await self._client.chat.completions.create(
            messages=messages,
            model=self.config.model,
            temperature=self.config.temperature,
            tools=provider_functions,
            stream=True
        )

        # Progressive streaming state
        streaming_state = StreamingExecutionState()
        collected_text = ""
        function_calls_in_progress = {}

        # Process streaming chunks
        async for chunk in stream:
            # Handle usage metadata
            if hasattr(chunk, 'usage') and chunk.usage:
                current_usage = self._extract_and_standardize_usage(chunk)
                accumulated_usage = self._accumulate_usage_safely(current_usage, accumulated_usage)

            # Handle text content
            if chunk.choices and chunk.choices[0].delta.content:
                collected_text += chunk.choices[0].delta.content
                yield {"llm_response": collected_text}

            # Handle function calls with progressive execution
            if chunk.choices and chunk.choices[0].delta.tool_calls:
                for tool_call_delta in chunk.choices[0].delta.tool_calls:
                    # Track function call progress
                    call_index = tool_call_delta.index
                    if call_index not in function_calls_in_progress:
                        function_calls_in_progress[call_index] = {
                            "name": "",
                            "arguments": ""
                        }

                    # Update function call data
                    if tool_call_delta.function.name:
                        function_calls_in_progress[call_index]["name"] = tool_call_delta.function.name
                    if tool_call_delta.function.arguments:
                        function_calls_in_progress[call_index]["arguments"] += tool_call_delta.function.arguments

                    # Execute as soon as the call is complete
                    if self._is_function_complete(function_calls_in_progress[call_index]):
                        function_call = self._create_function_call_object(
                            function_calls_in_progress[call_index],
                            call_index,
                            input
                        )
                        if not streaming_state.is_already_executed(function_call):
                            task = await self._execute_function_progressively(function_call, input)
                            streaming_state.add_function_task(task, function_call)

        # Gather function results and continue the conversation if needed
        if streaming_state.active_function_tasks:
            execution_result = await self._gather_progressive_results(streaming_state.active_function_tasks)
            self._add_function_results_to_messages(execution_result, messages)
            if execution_result.get('regular_results'):
                current_turn += 1
                continue
        break

    # Final usage yield
    yield {"llm_response": None, "usage": accumulated_usage}
```
Helper Utilities¶
Implement these utility methods for consistency:
Type Conversion:
```python
def _python_type_to_json_schema_type(self, python_type) -> str:
    """Convert Python types to JSON schema types."""
    if python_type == str:
        return "string"
    elif python_type == int:
        return "integer"
    elif python_type == float:
        return "number"
    elif python_type == bool:
        return "boolean"
    elif python_type == list or (hasattr(python_type, '__origin__') and python_type.__origin__ == list):
        return "array"
    elif python_type == dict or (hasattr(python_type, '__origin__') and python_type.__origin__ == dict):
        return "object"
    else:
        return "string"  # Safe default for unknown annotations
```
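The helper above maps `Optional[X]` to its `"string"` default, since `Union` has no direct JSON schema type. A small extension (illustrative only, not part of the framework) could unwrap the inner type first:

```python
from typing import Optional, Union, get_args, get_origin


def unwrap_optional(python_type):
    """Return the inner type of Optional[X]; other types pass through unchanged."""
    if get_origin(python_type) is Union:
        # Optional[X] is Union[X, None]; keep the single non-None member
        non_none = [a for a in get_args(python_type) if a is not type(None)]
        if len(non_none) == 1:
            return non_none[0]
    return python_type


inner = unwrap_optional(Optional[int])  # unwraps to int
passthrough = unwrap_optional(str)      # non-Optional types are untouched
```

Calling `unwrap_optional` before `_python_type_to_json_schema_type` lets optional parameters map to their real schema type; remember to also leave them out of the `required` list.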
Function Call Processing:
```python
import json
from typing import List


def _prepare_function_calls_for_orchestrator(self, function_calls, input: ILLMInput) -> List[FunctionCall]:
    """Convert provider function calls to framework FunctionCall objects."""
    function_calls_list = []
    for i, call in enumerate(function_calls):
        function_name = call.function.name
        try:
            function_args = json.loads(call.function.arguments) if call.function.arguments else {}
        except json.JSONDecodeError:
            function_args = {}
        call_id = f"{function_name}_{i}"
        is_background = function_name in (input.background_tasks or {})
        function_calls_list.append(FunctionCall(
            name=function_name,
            args=function_args,
            call_id=call_id,
            is_background=is_background
        ))
    return function_calls_list
```
Testing Your Client¶
Create comprehensive tests for your client:
```python
import pytest

from arshai.core.interfaces.illm import ILLMConfig, ILLMInput
from your_package.custom_llm_client import CustomLLMClient


@pytest.fixture
def client_config():
    return ILLMConfig(
        model="your-model-name",
        temperature=0.7,
        max_tokens=500
    )


@pytest.fixture
def client(client_config):
    return CustomLLMClient(client_config)


@pytest.mark.asyncio
async def test_simple_chat(client):
    """Test basic chat functionality."""
    input_data = ILLMInput(
        system_prompt="You are a helpful assistant",
        user_message="Hello, how are you?"
    )
    response = await client.chat(input_data)
    assert "llm_response" in response
    assert "usage" in response
    assert isinstance(response["llm_response"], str)
    assert len(response["llm_response"]) > 0


@pytest.mark.asyncio
async def test_function_calling(client):
    """Test function calling capabilities."""
    def test_function(x: int, y: int) -> int:
        """Add two numbers."""
        return x + y

    input_data = ILLMInput(
        system_prompt="You can use tools to perform calculations",
        user_message="What is 5 + 3?",
        regular_functions={"test_function": test_function}
    )
    response = await client.chat(input_data)
    assert "llm_response" in response
    assert "8" in response["llm_response"] or "eight" in response["llm_response"].lower()


@pytest.mark.asyncio
async def test_streaming(client):
    """Test streaming functionality."""
    input_data = ILLMInput(
        system_prompt="You are a helpful assistant",
        user_message="Tell me a short story"
    )
    chunks = []
    async for chunk in client.stream(input_data):
        chunks.append(chunk)
    assert len(chunks) > 0
    assert any(chunk.get("llm_response") for chunk in chunks)
```
Best Practices¶
1. Follow the Reference Implementation: Use the Gemini client as your template for architecture and patterns.
2. Implement Defensive Programming: Handle edge cases gracefully and provide meaningful error messages.
3. Support All Framework Features: Implement function calling, structured output, background tasks, and streaming.
4. Maintain Provider Abstractions: Hide provider-specific details behind the standard interface.
5. Test Thoroughly: Use the framework test patterns to ensure compatibility.
6. Document Provider-Specific Features: Clearly document any limitations or special capabilities.
7. Handle Rate Limiting: Implement appropriate retry logic for your provider's rate limits.
8. Use Safe HTTP Configuration: Prefer SafeHttpClientFactory when available for better reliability.
Common Patterns¶
Environment Variable Configuration:
```python
import os


def __init__(self, config: ILLMConfig, **kwargs):
    # Provider-specific environment variables
    self.api_key = os.environ.get("YOUR_PROVIDER_API_KEY")
    self.base_url = os.environ.get("YOUR_PROVIDER_BASE_URL", "https://api.yourprovider.com")
    if not self.api_key:
        raise ValueError("YOUR_PROVIDER_API_KEY environment variable required")
    super().__init__(config, **kwargs)
```
Error Handling:
```python
try:
    response = await self._client.chat(...)
except YourProviderRateLimitError:
    raise Exception("Rate limit exceeded - implement retry logic")
except YourProviderAuthError:
    raise Exception("Authentication failed - check API key")
except Exception as e:
    self.logger.error(f"Provider error: {e}")
    raise
```
Context Management:
```python
def close(self):
    """Clean up provider resources."""
    try:
        if hasattr(self._client, 'close'):
            self._client.close()
    except Exception as e:
        self.logger.warning(f"Error closing client: {e}")
```
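Callers can be spared from remembering `close()` by also implementing the context-manager protocol. The `ManagedClient` class below is a stand-in sketch, not the Arshai base class:

```python
class ManagedClient:
    """Minimal client showing close() wired into a with-statement."""

    def __init__(self):
        self.closed = False

    def close(self):
        # Real clients would release HTTP sessions, sockets, etc.
        self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # don't swallow exceptions from the with-body


with ManagedClient() as client:
    pass  # client usable here; close() runs on exit, even on error
```

If your provider's SDK is async and exposes an `aclose()`-style coroutine, the async equivalents (`__aenter__`/`__aexit__` with `async with`) follow the same shape.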
This comprehensive guide provides everything needed to create a custom LLM client that integrates seamlessly with the Arshai framework. Remember to use the Gemini client as your reference implementation for the most current patterns and best practices.