Multi-Provider LLM Integration with LiteLLM

Difficulty: Intermediate
Duration: 45 minutes
Reading time: 23 min
Last updated: 2024-01-15

Learn how to use LiteLLM to integrate multiple LLM providers with fallbacks, cost optimization, and unified APIs.

What You'll Learn

  • Set up LiteLLM with multiple providers
  • Implement fallback mechanisms for reliability
  • Track costs and optimize LLM usage
  • Build a unified API for multiple LLM providers
  • Handle rate limits and errors gracefully

Prerequisites

  • Basic Python knowledge
  • Understanding of APIs and HTTP requests
  • LLM provider API keys (OpenAI, Anthropic, etc.)
  • Python 3.8+ installed

What You'll Build

A production-ready multi-provider LLM service with cost optimization and reliability features

Overview

LiteLLM provides a unified, OpenAI-style interface to more than 100 LLMs across providers, making it easy to switch between models, implement fallbacks, and keep costs under control. In this tutorial, you’ll build a robust multi-provider LLM service that handles failures gracefully and tracks spend so you can optimize it.
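
To make the “unified interface” concrete: the same call shape works for every provider, and only the model string changes. A minimal sketch (assuming the relevant API keys are already set as environment variables; exact model names may vary by LiteLLM version):

# unified_interface.py -- minimal sketch of switching providers by model string
import litellm

messages = [{"role": "user", "content": "Say hello in one sentence."}]

# Assumes OPENAI_API_KEY (and, for the Claude line, ANTHROPIC_API_KEY) are set
for model in ["gpt-3.5-turbo", "claude-3-sonnet-20240229"]:
    try:
        response = litellm.completion(model=model, messages=messages, max_tokens=30)
        print(f"{model}: {response.choices[0].message.content}")
    except Exception as exc:
        print(f"{model} failed: {exc}")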

Step 1: Basic LiteLLM Setup and Provider Configuration

Let’s start by setting up LiteLLM with multiple providers and understanding the basic usage patterns.

Installation

pip install litellm fastapi pydantic uvicorn python-dotenv

Basic Provider Setup

Create a .env file:

# OpenAI
OPENAI_API_KEY=your-openai-api-key

# Anthropic (optional)
ANTHROPIC_API_KEY=your-anthropic-api-key

# Google (optional)
GOOGLE_API_KEY=your-google-api-key

# Azure OpenAI (optional)
AZURE_API_KEY=your-azure-api-key
AZURE_API_BASE=https://your-resource.openai.azure.com/
AZURE_API_VERSION=2023-12-01-preview
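
Before running the examples, it can help to confirm which keys were actually loaded. A small convenience sketch (the key names match the .env file above):

# check_keys.py -- sketch: confirm which provider keys are configured
import os
from dotenv import load_dotenv

load_dotenv()

for name in ["OPENAI_API_KEY", "ANTHROPIC_API_KEY", "GOOGLE_API_KEY", "AZURE_API_KEY"]:
    print(f"{name}: {'set' if os.getenv(name) else 'missing'}")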

Basic Usage Examples

Create basic_usage.py:

import litellm
import os
from dotenv import load_dotenv
import asyncio

# Load environment variables
load_dotenv()

# Set API keys explicitly (optional: LiteLLM also reads the standard provider
# environment variables such as OPENAI_API_KEY automatically)
litellm.openai_key = os.getenv("OPENAI_API_KEY")
litellm.anthropic_key = os.getenv("ANTHROPIC_API_KEY")
litellm.google_key = os.getenv("GOOGLE_API_KEY")

def test_basic_completion():
    """Test basic completion with different providers"""
    
    messages = [{"role": "user", "content": "Explain quantum computing in simple terms"}]
    
    print("=== Testing Different Providers ===\n")
    
    # OpenAI GPT-3.5
    try:
        print("1. OpenAI GPT-3.5 Turbo:")
        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=100
        )
        print(f"Response: {response.choices[0].message.content}\n")
        print(f"Usage: {response.usage}\n")
    except Exception as e:
        print(f"Error with OpenAI: {e}\n")
    
    # OpenAI GPT-4
    try:
        print("2. OpenAI GPT-4:")
        response = litellm.completion(
            model="gpt-4",
            messages=messages,
            max_tokens=100
        )
        print(f"Response: {response.choices[0].message.content}\n")
        print(f"Usage: {response.usage}\n")
    except Exception as e:
        print(f"Error with GPT-4: {e}\n")
    
    # Anthropic Claude (if API key is available)
    if os.getenv("ANTHROPIC_API_KEY"):
        try:
            print("3. Anthropic Claude:")
            response = litellm.completion(
                model="claude-3-sonnet-20240229",
                messages=messages,
                max_tokens=100
            )
            print(f"Response: {response.choices[0].message.content}\n")
            print(f"Usage: {response.usage}\n")
        except Exception as e:
            print(f"Error with Claude: {e}\n")
    
    # Google Gemini (if API key is available)
    # Note: depending on your LiteLLM version, Gemini models may need the
    # provider prefix, e.g. "gemini/gemini-pro"
    if os.getenv("GOOGLE_API_KEY"):
        try:
            print("4. Google Gemini:")
            response = litellm.completion(
                model="gemini-pro",
                messages=messages,
                max_tokens=100
            )
            print(f"Response: {response.choices[0].message.content}\n")
            print(f"Usage: {response.usage}\n")
        except Exception as e:
            print(f"Error with Gemini: {e}\n")

async def test_async_completion():
    """Test async completion for better performance"""
    
    print("=== Testing Async Completion ===\n")
    
    messages = [{"role": "user", "content": "What are the benefits of async programming?"}]
    
    try:
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=150
        )
        print(f"Async Response: {response.choices[0].message.content}\n")
    except Exception as e:
        print(f"Async Error: {e}\n")

def test_streaming():
    """Test streaming responses for real-time applications"""
    
    print("=== Testing Streaming Response ===\n")
    
    messages = [{"role": "user", "content": "Write a short story about AI"}]
    
    try:
        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=200,
            stream=True
        )
        
        print("Streaming response:")
        for chunk in response:
            if chunk.choices[0].delta.content:
                print(chunk.choices[0].delta.content, end="", flush=True)
        print("\n")
        
    except Exception as e:
        print(f"Streaming Error: {e}\n")

def test_cost_calculation():
    """Test cost calculation for different models"""
    
    print("=== Testing Cost Calculation ===\n")
    
    messages = [{"role": "user", "content": "Calculate the cost of this request"}]
    
    models_to_test = ["gpt-3.5-turbo", "gpt-4", "claude-3-sonnet-20240229"]
    
    for model in models_to_test:
        try:
            response = litellm.completion(
                model=model,
                messages=messages,
                max_tokens=50
            )
            
            # Calculate cost
            cost = litellm.completion_cost(completion_response=response)
            
            print(f"Model: {model}")
            print(f"Tokens used: {response.usage.total_tokens}")
            print(f"Cost: ${cost:.6f}\n")
            
        except Exception as e:
            print(f"Error with {model}: {e}\n")

if __name__ == "__main__":
    test_basic_completion()
    
    # Test async (requires running in async context)
    asyncio.run(test_async_completion())
    
    test_streaming()
    test_cost_calculation()
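
LiteLLM also lets you combine async and streaming: calling acompletion(..., stream=True) gives you chunks you can iterate with async for. A short sketch, kept separate from basic_usage.py and assuming the same environment setup:

# async_stream.py -- sketch combining acompletion with stream=True
import asyncio
import litellm
from dotenv import load_dotenv

load_dotenv()

async def stream_async():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Write a haiku about fallbacks"}],
        max_tokens=60,
        stream=True,
    )
    # With stream=True the awaited result is an async iterator of chunks
    async for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)
    print()

if __name__ == "__main__":
    asyncio.run(stream_async())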

Expected Output

Run the basic usage example:

python basic_usage.py

You should see output like:

=== Testing Different Providers ===

1. OpenAI GPT-3.5 Turbo:
Response: Quantum computing is a revolutionary technology that uses the principles of quantum mechanics...

Usage: CompletionUsage(completion_tokens=95, prompt_tokens=12, total_tokens=107)

2. OpenAI GPT-4:
Response: Quantum computing harnesses quantum mechanical phenomena like superposition and entanglement...

Usage: CompletionUsage(completion_tokens=98, prompt_tokens=12, total_tokens=110)

=== Testing Async Completion ===

Async Response: Async programming allows multiple tasks to run concurrently...

=== Testing Streaming Response ===

Streaming response:
Once upon a time, in a world where artificial intelligence had become...

=== Testing Cost Calculation ===

Model: gpt-3.5-turbo
Tokens used: 107
Cost: $0.000161

Model: gpt-4
Tokens used: 110
Cost: $0.003300

Step 2: Fallbacks, Load Balancing, and Error Handling

Now let’s implement advanced features like fallbacks, load balancing, and robust error handling.

Router Configuration

Create router_config.py:

from litellm import Router
import os
from dotenv import load_dotenv

load_dotenv()

def create_router():
    """Create a LiteLLM router with multiple models and fallbacks"""
    
    model_list = [
        # Primary models (fast and cost-effective)
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "model_info": {
                "tier": "primary",
                "cost_per_token": 0.0000015,
                "max_tokens": 4096
            }
        },
        
        # Secondary models (higher quality, more expensive)
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "gpt-4",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "model_info": {
                "tier": "secondary",
                "cost_per_token": 0.00003,
                "max_tokens": 8192
            }
        },
        
        # Fallback models (alternative providers)
        {
            "model_name": "claude-3-sonnet",
            "litellm_params": {
                "model": "claude-3-sonnet-20240229",
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
            },
            "model_info": {
                "tier": "fallback",
                "cost_per_token": 0.000015,
                "max_tokens": 200000
            }
        } if os.getenv("ANTHROPIC_API_KEY") else None,
        
        # Load balancing - multiple instances of the same model
        {
            "model_name": "gpt-3.5-turbo-backup",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "model_info": {
                "tier": "primary",
                "cost_per_token": 0.0000015,
                "max_tokens": 4096
            }
        }
    ]
    
    # Filter out None entries (models without API keys)
    model_list = [model for model in model_list if model is not None]
    
    # Create router with configuration
    router = Router(
        model_list=model_list,
        fallbacks=[
            # Note: "claude-3-sonnet" is only registered above when
            # ANTHROPIC_API_KEY is set; drop it from these lists otherwise
            {"gpt-3.5-turbo": ["gpt-3.5-turbo-backup", "claude-3-sonnet"]},
            {"gpt-4": ["gpt-3.5-turbo", "claude-3-sonnet"]}
        ],
        context_window_fallbacks=[
            {"gpt-3.5-turbo": ["gpt-4", "claude-3-sonnet"]},
        ],
        set_verbose=True,  # Enable detailed logging
        num_retries=3,     # Retry failed requests
        timeout=30,        # Request timeout in seconds
        allowed_fails=3,   # Number of allowed failures before marking model as unhealthy
        cooldown_time=60   # Cooldown time for unhealthy models
    )
    
    return router

def test_fallback_mechanism():
    """Test fallback mechanism with simulated failures"""
    
    router = create_router()
    
    print("=== Testing Fallback Mechanism ===\n")
    
    messages = [{"role": "user", "content": "Test fallback mechanism"}]
    
    # Test normal operation
    try:
        print("1. Normal operation:")
        response = router.completion(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=50
        )
        print(f"Success with model: {response.model}")
        print(f"Response: {response.choices[0].message.content}\n")
    except Exception as e:
        print(f"Error: {e}\n")
    
    # Test with invalid model (should fallback)
    try:
        print("2. Testing fallback with invalid primary model:")
        response = router.completion(
            model="invalid-model",
            messages=messages,
            max_tokens=50,
            fallbacks=["gpt-3.5-turbo", "claude-3-sonnet"]
        )
        print(f"Fallback successful with model: {response.model}")
        print(f"Response: {response.choices[0].message.content}\n")
    except Exception as e:
        print(f"Fallback failed: {e}\n")

def test_load_balancing():
    """Test load balancing across multiple model instances"""
    
    router = create_router()
    
    print("=== Testing Load Balancing ===\n")
    
    messages = [{"role": "user", "content": "Test load balancing"}]
    
    # Make multiple requests to see load balancing in action
    for i in range(5):
        try:
            response = router.completion(
                model="gpt-3.5-turbo",
                messages=messages,
                max_tokens=30
            )
            print(f"Request {i+1}: Model used: {response.model}")
        except Exception as e:
            print(f"Request {i+1} failed: {e}")
    
    print()

def test_context_window_fallback():
    """Test context window fallback for long inputs"""
    
    router = create_router()
    
    print("=== Testing Context Window Fallback ===\n")
    
    # Create a very long message that might exceed context window
    long_content = "This is a test message. " * 1000  # Very long message
    messages = [{"role": "user", "content": long_content}]
    
    try:
        response = router.completion(
            model="gpt-3.5-turbo",  # Smaller context window
            messages=messages,
            max_tokens=50
        )
        print(f"Success with model: {response.model}")
        print("Context window fallback worked correctly\n")
    except Exception as e:
        print(f"Context window fallback failed: {e}\n")

if __name__ == "__main__":
    test_fallback_mechanism()
    test_load_balancing()
    test_context_window_fallback()
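
The router exposes the same async entry point used later in this tutorial (acompletion). A short usage sketch, kept separate from router_config.py and assuming that file is importable:

# router_usage.py -- sketch: call the router asynchronously and see which deployment answered
import asyncio
from router_config import create_router

async def main():
    router = create_router()
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "One sentence on why fallbacks matter."}],
        max_tokens=40,
    )
    # response.model reports the underlying deployment that served the request
    print(f"{response.model}: {response.choices[0].message.content}")

if __name__ == "__main__":
    asyncio.run(main())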

Advanced Error Handling

Create error_handling.py:

import os
import time
import random
import asyncio
import logging
from typing import Dict, Any, Optional

from litellm import Router
from dotenv import load_dotenv

# Load environment variables at import time so the router can read OPENAI_API_KEY
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LLMService:
    """Advanced LLM service with comprehensive error handling"""
    
    def __init__(self):
        self.router = self._create_router()
        self.request_history = []
        self.error_counts = {}
    
    def _create_router(self):
        """Create router with error handling configuration"""
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            }
        ]
        
        return Router(
            model_list=model_list,
            fallbacks=[{"gpt-3.5-turbo": ["gpt-4"]}],
            num_retries=3,
            timeout=30,
            allowed_fails=5,
            cooldown_time=120
        )
    
    def _handle_rate_limit(self, error: Exception, model: str) -> Optional[Dict[str, Any]]:
        """Handle rate limit errors with exponential backoff"""
        if "rate limit" in str(error).lower():
            wait_time = min(60, 2 ** self.error_counts.get(model, 0))
            logger.warning(f"Rate limit hit for {model}. Waiting {wait_time} seconds...")
            time.sleep(wait_time)
            
            self.error_counts[model] = self.error_counts.get(model, 0) + 1
            return {"retry": True, "wait_time": wait_time}
        
        return None
    
    def _handle_context_length_error(self, error: Exception, messages: list) -> Optional[Dict[str, Any]]:
        """Handle context length errors by truncating messages"""
        if "context length" in str(error).lower() or "token limit" in str(error).lower():
            logger.warning("Context length exceeded. Truncating messages...")
            
            # Simple truncation strategy - keep system message and last few user messages
            truncated_messages = []
            
            # Keep system message if present
            if messages and messages[0].get("role") == "system":
                truncated_messages.append(messages[0])
                remaining_messages = messages[1:]
            else:
                remaining_messages = messages
            
            # Keep last 3 messages or until we're under a reasonable token estimate
            truncated_messages.extend(remaining_messages[-3:])
            
            return {"retry": True, "messages": truncated_messages}
        
        return None
    
    def _handle_model_unavailable(self, error: Exception, model: str) -> Optional[Dict[str, Any]]:
        """Handle model unavailable errors"""
        unavailable_indicators = ["unavailable", "overloaded", "maintenance", "502", "503"]
        
        if any(indicator in str(error).lower() for indicator in unavailable_indicators):
            logger.warning(f"Model {model} is unavailable. Will try fallback...")
            return {"use_fallback": True}
        
        return None
    
    async def complete_with_retry(
        self,
        messages: list,
        model: str = "gpt-3.5-turbo",
        max_retries: int = 3,
        **kwargs
    ) -> Dict[str, Any]:
        """Complete with comprehensive error handling and retry logic"""
        
        last_error = None
        current_messages = messages.copy()
        
        for attempt in range(max_retries + 1):
            try:
                logger.info(f"Attempt {attempt + 1} for model {model}")
                
                response = await self.router.acompletion(
                    model=model,
                    messages=current_messages,
                    **kwargs
                )
                
                # Reset error count on success
                if model in self.error_counts:
                    del self.error_counts[model]
                
                # Log successful request
                self.request_history.append({
                    "timestamp": time.time(),
                    "model": response.model,
                    "success": True,
                    "attempt": attempt + 1
                })
                
                return {
                    "success": True,
                    "response": response,
                    "model_used": response.model,
                    "attempts": attempt + 1
                }
                
            except Exception as error:
                last_error = error
                logger.error(f"Attempt {attempt + 1} failed: {error}")
                
                # Try different error handling strategies
                rate_limit_result = self._handle_rate_limit(error, model)
                if rate_limit_result and rate_limit_result.get("retry"):
                    continue
                
                context_result = self._handle_context_length_error(error, current_messages)
                if context_result and context_result.get("retry"):
                    current_messages = context_result["messages"]
                    continue
                
                model_unavailable_result = self._handle_model_unavailable(error, model)
                if model_unavailable_result and model_unavailable_result.get("use_fallback"):
                    # Let the router handle fallback automatically
                    pass
                
                # If this is not the last attempt, wait before retrying
                # (asyncio.sleep avoids blocking the event loop)
                if attempt < max_retries:
                    wait_time = min(30, 2 ** attempt + random.uniform(0, 1))
                    logger.info(f"Waiting {wait_time:.2f} seconds before retry...")
                    await asyncio.sleep(wait_time)
        
        # All attempts failed
        self.request_history.append({
            "timestamp": time.time(),
            "model": model,
            "success": False,
            "error": str(last_error),
            "attempts": max_retries + 1
        })
        
        return {
            "success": False,
            "error": str(last_error),
            "attempts": max_retries + 1
        }
    
    def get_service_health(self) -> Dict[str, Any]:
        """Get service health metrics"""
        recent_requests = [
            req for req in self.request_history 
            if time.time() - req["timestamp"] < 3600  # Last hour
        ]
        
        if not recent_requests:
            return {"status": "no_data", "requests_last_hour": 0}
        
        success_rate = sum(1 for req in recent_requests if req["success"]) / len(recent_requests)
        
        return {
            "status": "healthy" if success_rate > 0.9 else "degraded" if success_rate > 0.5 else "unhealthy",
            "success_rate": success_rate,
            "requests_last_hour": len(recent_requests),
            "error_counts": self.error_counts.copy(),
            "avg_attempts": sum(req["attempts"] for req in recent_requests) / len(recent_requests)
        }

# Test the error handling
async def test_error_handling():
    """Test the error handling service"""
    
    service = LLMService()
    
    print("=== Testing Error Handling Service ===\n")
    
    # Test normal operation
    messages = [{"role": "user", "content": "Hello, how are you?"}]
    
    result = await service.complete_with_retry(
        messages=messages,
        model="gpt-3.5-turbo",
        max_tokens=50
    )
    
    if result["success"]:
        print(f"✅ Success: {result['response'].choices[0].message.content}")
        print(f"Model used: {result['model_used']}")
        print(f"Attempts: {result['attempts']}\n")
    else:
        print(f"❌ Failed: {result['error']}")
        print(f"Attempts: {result['attempts']}\n")
    
    # Check service health
    health = service.get_service_health()
    print(f"Service Health: {health}\n")

if __name__ == "__main__":
    asyncio.run(test_error_handling())

Step 3: Cost Tracking and FastAPI Integration

Now let’s build a complete FastAPI service with cost tracking and monitoring.

Cost Tracking Service

Create cost_tracker.py:

import litellm
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import json
from dataclasses import dataclass, asdict
from collections import defaultdict

@dataclass
class UsageRecord:
    timestamp: datetime
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost: float
    user_id: Optional[str] = None
    request_id: Optional[str] = None

class CostTracker:
    """Track and analyze LLM usage costs"""
    
    def __init__(self):
        self.usage_records: List[UsageRecord] = []
        self.daily_budgets: Dict[str, float] = {}  # user_id -> daily budget
        self.model_costs = self._load_model_costs()
    
    def _load_model_costs(self) -> Dict[str, Dict[str, float]]:
        """Load per-token model cost information (illustrative values; verify against current provider pricing)"""
        return {
            "gpt-3.5-turbo": {
                "input_cost_per_token": 0.0000015,
                "output_cost_per_token": 0.000002
            },
            "gpt-4": {
                "input_cost_per_token": 0.00003,
                "output_cost_per_token": 0.00006
            },
            "claude-3-sonnet-20240229": {
                "input_cost_per_token": 0.000015,
                "output_cost_per_token": 0.000075
            },
            "gemini-pro": {
                "input_cost_per_token": 0.000001,
                "output_cost_per_token": 0.000002
            }
        }
    
    def calculate_cost(self, model: str, prompt_tokens: int, completion_tokens: int) -> float:
        """Calculate cost for a specific request"""
        if model not in self.model_costs:
            # Fallback to LiteLLM's cost calculation
            try:
                # Create a mock response object for cost calculation
                mock_response = type('MockResponse', (), {
                    'model': model,
                    'usage': type('Usage', (), {
                        'prompt_tokens': prompt_tokens,
                        'completion_tokens': completion_tokens,
                        'total_tokens': prompt_tokens + completion_tokens
                    })()
                })()
                return litellm.completion_cost(completion_response=mock_response)
            except Exception:
                return 0.0
        
        costs = self.model_costs[model]
        input_cost = prompt_tokens * costs["input_cost_per_token"]
        output_cost = completion_tokens * costs["output_cost_per_token"]
        return input_cost + output_cost
    
    def record_usage(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        user_id: Optional[str] = None,
        request_id: Optional[str] = None
    ) -> UsageRecord:
        """Record usage for cost tracking"""
        
        cost = self.calculate_cost(model, prompt_tokens, completion_tokens)
        
        record = UsageRecord(
            timestamp=datetime.now(),
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            cost=cost,
            user_id=user_id,
            request_id=request_id
        )
        
        self.usage_records.append(record)
        return record
    
    def get_usage_summary(
        self,
        user_id: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> Dict:
        """Get usage summary with filtering options"""
        
        # Filter records
        filtered_records = self.usage_records
        
        if user_id:
            filtered_records = [r for r in filtered_records if r.user_id == user_id]
        
        if start_date:
            filtered_records = [r for r in filtered_records if r.timestamp >= start_date]
        
        if end_date:
            filtered_records = [r for r in filtered_records if r.timestamp <= end_date]
        
        if not filtered_records:
            return {
                "total_cost": 0,
                "total_tokens": 0,
                "request_count": 0,
                "average_cost_per_request": 0,
                "model_breakdown": {},
                "daily_breakdown": {}
            }
        
        # Calculate totals
        total_cost = sum(r.cost for r in filtered_records)
        total_tokens = sum(r.total_tokens for r in filtered_records)
        request_count = len(filtered_records)
        
        # Model breakdown
        model_breakdown = defaultdict(lambda: {"cost": 0, "tokens": 0, "requests": 0})
        for record in filtered_records:
            model_breakdown[record.model]["cost"] += record.cost
            model_breakdown[record.model]["tokens"] += record.total_tokens
            model_breakdown[record.model]["requests"] += 1
        
        # Daily breakdown
        daily_breakdown = defaultdict(lambda: {"cost": 0, "tokens": 0, "requests": 0})
        for record in filtered_records:
            date_key = record.timestamp.date().isoformat()
            daily_breakdown[date_key]["cost"] += record.cost
            daily_breakdown[date_key]["tokens"] += record.total_tokens
            daily_breakdown[date_key]["requests"] += 1
        
        return {
            "total_cost": total_cost,
            "total_tokens": total_tokens,
            "request_count": request_count,
            "average_cost_per_request": total_cost / request_count if request_count > 0 else 0,
            "model_breakdown": dict(model_breakdown),
            "daily_breakdown": dict(daily_breakdown)
        }
    
    def check_budget_limit(self, user_id: str) -> Dict[str, Any]:
        """Check if user is within budget limits"""
        if user_id not in self.daily_budgets:
            return {"within_budget": True, "message": "No budget set"}
        
        today = datetime.now().date()
        start_of_day = datetime.combine(today, datetime.min.time())
        
        today_usage = self.get_usage_summary(
            user_id=user_id,
            start_date=start_of_day
        )
        
        daily_budget = self.daily_budgets[user_id]
        today_cost = today_usage["total_cost"]
        
        within_budget = today_cost <= daily_budget
        remaining_budget = daily_budget - today_cost
        
        return {
            "within_budget": within_budget,
            "daily_budget": daily_budget,
            "today_cost": today_cost,
            "remaining_budget": remaining_budget,
            "usage_percentage": (today_cost / daily_budget) * 100 if daily_budget > 0 else 0
        }
    
    def set_daily_budget(self, user_id: str, budget: float):
        """Set daily budget for a user"""
        self.daily_budgets[user_id] = budget
    
    def get_cost_optimization_suggestions(self) -> List[Dict[str, Any]]:
        """Get suggestions for cost optimization"""
        suggestions = []
        
        if not self.usage_records:
            return suggestions
        
        # Analyze recent usage (last 7 days)
        week_ago = datetime.now() - timedelta(days=7)
        recent_records = [r for r in self.usage_records if r.timestamp >= week_ago]
        
        if not recent_records:
            return suggestions
        
        # Model usage analysis
        model_usage = defaultdict(lambda: {"cost": 0, "requests": 0})
        for record in recent_records:
            model_usage[record.model]["cost"] += record.cost
            model_usage[record.model]["requests"] += 1
        
        # Suggest cheaper alternatives for expensive models
        total_cost = sum(data["cost"] for data in model_usage.values())
        
        for model, data in model_usage.items():
            if data["cost"] / total_cost > 0.5:  # Model accounts for >50% of costs
                if model == "gpt-4":
                    suggestions.append({
                        "type": "model_substitution",
                        "message": f"Consider using gpt-3.5-turbo for simpler tasks. GPT-4 accounts for {(data['cost']/total_cost)*100:.1f}% of your costs.",
                        "potential_savings": data["cost"] * 0.9  # Rough estimate
                    })
                elif model == "claude-3-sonnet-20240229":
                    suggestions.append({
                        "type": "model_substitution",
                        "message": "Consider using gpt-3.5-turbo for cost-sensitive applications.",
                        "potential_savings": data["cost"] * 0.8
                    })
        
        # Token usage optimization
        avg_tokens = sum(r.total_tokens for r in recent_records) / len(recent_records)
        if avg_tokens > 2000:
            suggestions.append({
                "type": "token_optimization",
                "message": f"Average token usage is {avg_tokens:.0f}. Consider shortening prompts or responses.",
                "potential_savings": total_cost * 0.2
            })
        
        return suggestions

# Test cost tracking
def test_cost_tracking():
    """Test the cost tracking functionality"""
    
    tracker = CostTracker()
    
    print("=== Testing Cost Tracking ===\n")
    
    # Set budget for test user
    tracker.set_daily_budget("user123", 10.00)
    
    # Simulate some usage
    test_usage = [
        {"model": "gpt-3.5-turbo", "prompt_tokens": 100, "completion_tokens": 50, "user_id": "user123"},
        {"model": "gpt-4", "prompt_tokens": 200, "completion_tokens": 100, "user_id": "user123"},
        {"model": "gpt-3.5-turbo", "prompt_tokens": 150, "completion_tokens": 75, "user_id": "user456"},
        {"model": "claude-3-sonnet-20240229", "prompt_tokens": 300, "completion_tokens": 150, "user_id": "user123"},
    ]
    
    for usage in test_usage:
        record = tracker.record_usage(**usage)
        print(f"Recorded: {record.model} - ${record.cost:.6f}")
    
    print("\n=== Usage Summary ===")
    
    # Overall summary
    summary = tracker.get_usage_summary()
    print(f"Total cost: ${summary['total_cost']:.6f}")
    print(f"Total tokens: {summary['total_tokens']}")
    print(f"Total requests: {summary['request_count']}")
    
    print("\nModel breakdown:")
    for model, data in summary['model_breakdown'].items():
        print(f"  {model}: ${data['cost']:.6f} ({data['requests']} requests)")
    
    # User-specific summary
    print(f"\n=== User123 Summary ===")
    user_summary = tracker.get_usage_summary(user_id="user123")
    print(f"User cost: ${user_summary['total_cost']:.6f}")
    
    # Budget check
    budget_status = tracker.check_budget_limit("user123")
    print(f"Budget status: {budget_status}")
    
    # Cost optimization suggestions
    print(f"\n=== Cost Optimization Suggestions ===")
    suggestions = tracker.get_cost_optimization_suggestions()
    for suggestion in suggestions:
        print(f"- {suggestion['message']}")
        if 'potential_savings' in suggestion:
            print(f"  Potential savings: ${suggestion['potential_savings']:.6f}")

if __name__ == "__main__":
    test_cost_tracking()
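
The tracker above keeps records only in memory, so they are lost when the process exits. One simple option is to persist the dataclass records as JSON between runs; a minimal sketch (the file name is arbitrary):

# persist_usage.py -- sketch: save/load UsageRecord entries as JSON
import json
from dataclasses import asdict
from datetime import datetime

from cost_tracker import CostTracker, UsageRecord

def save_records(tracker: CostTracker, path: str = "usage_records.json") -> None:
    data = [
        {**asdict(r), "timestamp": r.timestamp.isoformat()}
        for r in tracker.usage_records
    ]
    with open(path, "w") as f:
        json.dump(data, f, indent=2)

def load_records(tracker: CostTracker, path: str = "usage_records.json") -> None:
    with open(path) as f:
        data = json.load(f)
    for item in data:
        item["timestamp"] = datetime.fromisoformat(item["timestamp"])
        tracker.usage_records.append(UsageRecord(**item))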

FastAPI Integration

Create main.py:

from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Query
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import litellm
from litellm import Router
import os
from dotenv import load_dotenv
import uuid
from datetime import datetime, timedelta
import asyncio

from cost_tracker import CostTracker
from router_config import create_router
from error_handling import LLMService

load_dotenv()

# Initialize services
app = FastAPI(
    title="Multi-Provider LLM API",
    description="A unified API for multiple LLM providers with cost tracking and fallbacks",
    version="1.0.0"
)

cost_tracker = CostTracker()
llm_service = LLMService()

# Pydantic models
class ChatMessage(BaseModel):
    # Pydantic v2 uses pattern= (on Pydantic v1, use regex= instead)
    role: str = Field(..., pattern="^(system|user|assistant)$")
    content: str = Field(..., min_length=1)

class CompletionRequest(BaseModel):
    messages: List[ChatMessage]
    model: str = Field(default="gpt-3.5-turbo")
    max_tokens: Optional[int] = Field(default=150, ge=1, le=4000)
    temperature: Optional[float] = Field(default=0.7, ge=0, le=2)
    user_id: Optional[str] = None

class CompletionResponse(BaseModel):
    id: str
    model: str
    content: str
    usage: Dict[str, int]
    cost: float
    timestamp: datetime

class UsageSummaryResponse(BaseModel):
    total_cost: float
    total_tokens: int
    request_count: int
    average_cost_per_request: float
    model_breakdown: Dict[str, Dict[str, Any]]

class BudgetRequest(BaseModel):
    user_id: str
    daily_budget: float = Field(..., gt=0)

class BudgetStatusResponse(BaseModel):
    within_budget: bool
    daily_budget: float
    today_cost: float
    remaining_budget: float
    usage_percentage: float

# Dependency to get cost tracker
def get_cost_tracker():
    return cost_tracker

# API Endpoints
@app.post("/completion", response_model=CompletionResponse)
async def create_completion(
    request: CompletionRequest,
    background_tasks: BackgroundTasks,
    tracker: CostTracker = Depends(get_cost_tracker)
):
    """Create a completion with cost tracking and error handling"""
    
    # Check budget if user_id is provided
    if request.user_id:
        budget_status = tracker.check_budget_limit(request.user_id)
        if not budget_status["within_budget"]:
            raise HTTPException(
                status_code=429,
                detail=f"Daily budget exceeded. Used: ${budget_status['today_cost']:.4f}, Budget: ${budget_status['daily_budget']:.4f}"
            )
    
    # Convert Pydantic messages to dict format
    messages = [msg.model_dump() for msg in request.messages]
    
    # Make the completion request
    result = await llm_service.complete_with_retry(
        messages=messages,
        model=request.model,
        max_tokens=request.max_tokens,
        temperature=request.temperature
    )
    
    if not result["success"]:
        raise HTTPException(
            status_code=500,
            detail=f"Completion failed after {result['attempts']} attempts: {result['error']}"
        )
    
    response = result["response"]
    request_id = str(uuid.uuid4())
    
    # Record usage in background
    def record_usage_background():
        tracker.record_usage(
            model=response.model,
            prompt_tokens=response.usage.prompt_tokens,
            completion_tokens=response.usage.completion_tokens,
            user_id=request.user_id,
            request_id=request_id
        )
    
    background_tasks.add_task(record_usage_background)
    
    # Calculate cost
    cost = tracker.calculate_cost(
        response.model,
        response.usage.prompt_tokens,
        response.usage.completion_tokens
    )
    
    return CompletionResponse(
        id=request_id,
        model=response.model,
        content=response.choices[0].message.content,
        usage={
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens
        },
        cost=cost,
        timestamp=datetime.now()
    )

@app.get("/usage", response_model=UsageSummaryResponse)
async def get_usage_summary(
    user_id: Optional[str] = None,
    days: int = Query(default=7, ge=1, le=365),
    tracker: CostTracker = Depends(get_cost_tracker)
):
    """Get usage summary with optional filtering"""
    
    start_date = datetime.now() - timedelta(days=days)
    
    summary = tracker.get_usage_summary(
        user_id=user_id,
        start_date=start_date
    )
    
    return UsageSummaryResponse(**summary)

@app.post("/budget")
async def set_budget(
    request: BudgetRequest,
    tracker: CostTracker = Depends(get_cost_tracker)
):
    """Set daily budget for a user"""
    
    tracker.set_daily_budget(request.user_id, request.daily_budget)
    
    return {
        "message": f"Daily budget of ${request.daily_budget:.2f} set for user {request.user_id}"
    }

@app.get("/budget/{user_id}", response_model=BudgetStatusResponse)
async def get_budget_status(
    user_id: str,
    tracker: CostTracker = Depends(get_cost_tracker)
):
    """Get budget status for a user"""
    
    status = tracker.check_budget_limit(user_id)
    
    if "No budget set" in status.get("message", ""):
        raise HTTPException(
            status_code=404,
            detail=f"No budget set for user {user_id}"
        )
    
    return BudgetStatusResponse(**status)

@app.get("/optimization")
async def get_optimization_suggestions(
    tracker: CostTracker = Depends(get_cost_tracker)
):
    """Get cost optimization suggestions"""
    
    suggestions = tracker.get_cost_optimization_suggestions()
    
    return {
        "suggestions": suggestions,
        "total_potential_savings": sum(s.get("potential_savings", 0) for s in suggestions)
    }

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    
    service_health = llm_service.get_service_health()
    
    return {
        "status": "healthy",
        "timestamp": datetime.now(),
        "llm_service": service_health,
        "total_usage_records": len(cost_tracker.usage_records)
    }

@app.get("/models")
async def list_available_models():
    """List available models and their costs"""
    
    return {
        "models": cost_tracker.model_costs,
        "message": "Costs are per token. Multiply by token count for total cost."
    }

# Test the API
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

Test the Complete API

Create test_api.py:

import requests
import json
from datetime import datetime

BASE_URL = "http://localhost:8000"

def test_complete_api():
    """Test the complete LiteLLM FastAPI integration"""
    
    print("=== Testing Complete LiteLLM API ===\n")
    
    # 1. Test completion endpoint
    print("1. Testing completion endpoint...")
    completion_data = {
        "messages": [
            {"role": "user", "content": "Explain the benefits of using multiple LLM providers"}
        ],
        "model": "gpt-3.5-turbo",
        "max_tokens": 100,
        "temperature": 0.7,
        "user_id": "test_user"
    }
    
    response = requests.post(f"{BASE_URL}/completion", json=completion_data)
    if response.status_code == 200:
        result = response.json()
        print(f"✅ Completion successful:")
        print(f"   Model: {result['model']}")
        print(f"   Cost: ${result['cost']:.6f}")
        print(f"   Content: {result['content'][:100]}...")
        print(f"   Usage: {result['usage']}\n")
    else:
        print(f"❌ Completion failed: {response.text}\n")
    
    # 2. Set budget
    print("2. Setting user budget...")
    budget_data = {
        "user_id": "test_user",
        "daily_budget": 5.00
    }
    
    response = requests.post(f"{BASE_URL}/budget", json=budget_data)
    if response.status_code == 200:
        print(f"✅ Budget set: {response.json()['message']}\n")
    else:
        print(f"❌ Budget setting failed: {response.text}\n")
    
    # 3. Check budget status
    print("3. Checking budget status...")
    response = requests.get(f"{BASE_URL}/budget/test_user")
    if response.status_code == 200:
        budget_status = response.json()
        print(f"✅ Budget status:")
        print(f"   Within budget: {budget_status['within_budget']}")
        print(f"   Today's cost: ${budget_status['today_cost']:.6f}")
        print(f"   Remaining: ${budget_status['remaining_budget']:.6f}")
        print(f"   Usage: {budget_status['usage_percentage']:.1f}%\n")
    else:
        print(f"❌ Budget check failed: {response.text}\n")
    
    # 4. Get usage summary
    print("4. Getting usage summary...")
    response = requests.get(f"{BASE_URL}/usage?user_id=test_user&days=1")
    if response.status_code == 200:
        usage = response.json()
        print(f"✅ Usage summary:")
        print(f"   Total cost: ${usage['total_cost']:.6f}")
        print(f"   Total tokens: {usage['total_tokens']}")
        print(f"   Requests: {usage['request_count']}")
        print(f"   Avg cost/request: ${usage['average_cost_per_request']:.6f}\n")
    else:
        print(f"❌ Usage summary failed: {response.text}\n")
    
    # 5. Get optimization suggestions
    print("5. Getting optimization suggestions...")
    response = requests.get(f"{BASE_URL}/optimization")
    if response.status_code == 200:
        optimization = response.json()
        print(f"✅ Optimization suggestions:")
        for suggestion in optimization['suggestions']:
            print(f"   - {suggestion['message']}")
        print(f"   Total potential savings: ${optimization['total_potential_savings']:.6f}\n")
    else:
        print(f"❌ Optimization failed: {response.text}\n")
    
    # 6. Health check
    print("6. Checking service health...")
    response = requests.get(f"{BASE_URL}/health")
    if response.status_code == 200:
        health = response.json()
        print(f"✅ Service health: {health['status']}")
        print(f"   LLM service: {health['llm_service']['status']}")
        print(f"   Usage records: {health['total_usage_records']}\n")
    else:
        print(f"❌ Health check failed: {response.text}\n")
    
    # 7. List available models
    print("7. Listing available models...")
    response = requests.get(f"{BASE_URL}/models")
    if response.status_code == 200:
        models = response.json()
        print(f"✅ Available models:")
        for model, costs in models['models'].items():
            print(f"   - {model}: Input ${costs['input_cost_per_token']:.8f}/token, Output ${costs['output_cost_per_token']:.8f}/token")
    else:
        print(f"❌ Models list failed: {response.text}\n")
    
    print("🎉 Complete API test finished!")

if __name__ == "__main__":
    test_complete_api()

Run the complete application:

# Terminal 1: Start the FastAPI server
python main.py

# Terminal 2: Run the test
python test_api.py

Congratulations!

You’ve successfully built a production-ready multi-provider LLM service with:

  • Multi-Provider Support: Unified interface for multiple LLM providers
  • Fallback Mechanisms: Automatic failover for reliability
  • Cost Tracking: Comprehensive cost monitoring and budgeting
  • Error Handling: Robust error handling with retry logic
  • Load Balancing: Distribute requests across multiple models
  • FastAPI Integration: Type-safe REST API with automatic documentation
  • Budget Management: User-specific budget limits and monitoring
  • Optimization Suggestions: Automated cost optimization recommendations

This foundation gives you everything you need to build scalable, cost-effective LLM applications with the Pragmatic AI Stack.

Next Steps

  • Implement authentication and user management
  • Add request caching for cost optimization
  • Set up monitoring and alerting
  • Implement advanced routing strategies
  • Add support for streaming responses
  • Deploy to production with proper scaling