ADR-007-v4: AI Router Architecture - Part 2: Technical
Document Specification Block
Document: ADR-007-v4-ai-router-architecture-part2-technical
Version: 1.0.0
Purpose: Technical implementation details for CODITECT's multi-provider AI routing system
Audience: Developers, AI agents, technical implementers
Date Created: 2025-08-30
Date Modified: 2025-08-30
Status: DRAFT
Table of Contents
- Implementation Overview
- Provider Architecture
- Routing Algorithm
- Provider Implementations
- Caching Strategy
- Performance Monitoring
- Testing Requirements
- Configuration
- Previous Part
Implementation Overview
System Architecture
Core Components
```rust
// src/ai/router/mod.rs
use std::collections::HashMap;
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

use crate::ai::providers::{AIProvider, AIRequest};

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TaskType {
    CodeGeneration,
    CodeReview,
    Documentation,
    DataAnalysis,
    GeneralQuery,
    Translation,
    Summarization,
    CreativeWriting,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderCapability {
    pub task_types: Vec<TaskType>,
    pub max_tokens: usize,
    pub supports_streaming: bool,
    pub supports_functions: bool,
    pub cost_per_million_tokens: f64,
    pub average_latency_ms: u64,
    pub success_rate: f64,
}

#[derive(Clone)]
pub struct AIRouter {
    providers: Arc<RwLock<HashMap<String, Box<dyn AIProvider>>>>,
    capabilities: Arc<RwLock<HashMap<String, ProviderCapability>>>,
    cache: ResponseCache,
    metrics: MetricsCollector,
    config: RouterConfig,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
    pub default_provider: String,
    pub fallback_providers: Vec<String>,
    pub budget_limits: HashMap<String, f64>, // tenant_id -> monthly cost cap
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    pub provider: String,
    pub reasoning: String,
    pub estimated_cost: f64,
    pub estimated_latency_ms: u64,
    pub confidence_score: f64,
}

impl AIRouter {
    pub async fn route_request(
        &self,
        request: &AIRequest,
    ) -> Result<RoutingDecision, RouterError> {
        // Check the response cache first
        if self.cache.get(&request.cache_key()).await?.is_some() {
            return Ok(RoutingDecision {
                provider: "cache".to_string(),
                reasoning: "Cached response available".to_string(),
                estimated_cost: 0.0,
                estimated_latency_ms: 1,
                confidence_score: 1.0,
            });
        }

        // Classify task type
        let task_type = self.classify_task(request).await?;

        // Get available providers
        let providers = self.get_available_providers(&task_type).await?;

        // Score each provider
        let scores = self.score_providers(&providers, request, &task_type).await?;

        // Select the best provider
        let decision = self.select_provider(scores, request).await?;

        // Record the decision for metrics
        self.metrics.record_routing_decision(&decision).await;

        Ok(decision)
    }

    async fn classify_task(&self, request: &AIRequest) -> Result<TaskType, RouterError> {
        // Simple keyword-based classification (can be enhanced with ML)
        let prompt_lower = request.prompt.to_lowercase();
        if prompt_lower.contains("code")
            || prompt_lower.contains("function")
            || prompt_lower.contains("implement")
        {
            Ok(TaskType::CodeGeneration)
        } else if prompt_lower.contains("review") || prompt_lower.contains("analyze code") {
            Ok(TaskType::CodeReview)
        } else if prompt_lower.contains("document") || prompt_lower.contains("explain") {
            Ok(TaskType::Documentation)
        } else if prompt_lower.contains("data") || prompt_lower.contains("analyze") {
            Ok(TaskType::DataAnalysis)
        } else if prompt_lower.contains("translate") {
            Ok(TaskType::Translation)
        } else if prompt_lower.contains("summarize") {
            Ok(TaskType::Summarization)
        } else if prompt_lower.contains("write") || prompt_lower.contains("create") {
            Ok(TaskType::CreativeWriting)
        } else {
            Ok(TaskType::GeneralQuery)
        }
    }

    async fn score_providers(
        &self,
        providers: &[String],
        request: &AIRequest,
        task_type: &TaskType,
    ) -> Result<Vec<(String, f64)>, RouterError> {
        let mut scores = Vec::new();
        let capabilities = self.capabilities.read().await;

        for provider in providers {
            if let Some(cap) = capabilities.get(provider) {
                let mut score = 0.0;

                // Task compatibility (40% weight)
                if cap.task_types.contains(task_type) {
                    score += 40.0;
                }

                // Cost efficiency (30% weight)
                let cost_score = (100.0 - cap.cost_per_million_tokens) / 100.0 * 30.0;
                score += cost_score;

                // Performance (20% weight); clamped so slow providers score zero, not negative
                let perf_score = (1000.0 - cap.average_latency_ms as f64) / 1000.0 * 20.0;
                score += perf_score.max(0.0);

                // Reliability (10% weight)
                score += cap.success_rate * 10.0;

                // Budget constraints: exclude providers over the tenant's monthly cap
                if let Some(limit) = self.config.budget_limits.get(&request.tenant_id) {
                    let monthly_usage = self
                        .metrics
                        .get_monthly_usage(&request.tenant_id, provider)
                        .await?;
                    if monthly_usage >= *limit {
                        score = 0.0;
                    }
                }

                scores.push((provider.clone(), score));
            }
        }

        // Highest score first; total_cmp avoids the NaN panic of partial_cmp().unwrap()
        scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        Ok(scores)
    }
}
```
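As a sanity check on the weighting scheme (40% task fit, 30% cost, 20% latency, 10% reliability), the scoring arithmetic can be reproduced standalone. The provider figures below are illustrative, not measured values:

```rust
// Standalone restatement of the router's weighted scoring formula.
// Inputs mirror the ProviderCapability fields; the numbers are made up.
fn score(task_match: bool, cost_per_million: f64, latency_ms: u64, success_rate: f64) -> f64 {
    let mut s = 0.0;
    if task_match {
        s += 40.0; // task compatibility: 40% weight
    }
    s += (100.0 - cost_per_million) / 100.0 * 30.0; // cost efficiency: 30% weight
    s += ((1000.0 - latency_ms as f64) / 1000.0 * 20.0).max(0.0); // latency: 20% weight
    s += success_rate * 10.0; // reliability: 10% weight
    s
}

fn main() {
    // A free local model that matches the task outranks a costly, slower one:
    let local = score(true, 0.0, 500, 0.98); // 40 + 30 + 10 + 9.8  = 89.80
    let hosted = score(true, 15.0, 2000, 0.995); // 40 + 25.5 + 0 + 9.95 = 75.45
    println!("local={local:.2} hosted={hosted:.2}");
    assert!(local > hosted);
}
```

Note the latency term goes negative past 1000 ms, which is why the production code clamps it with `.max(0.0)`.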
Provider Architecture
Provider Trait
```rust
// src/ai/providers/mod.rs
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};

#[async_trait]
pub trait AIProvider: Send + Sync {
    fn name(&self) -> &str;

    async fn complete(&self, request: &AIRequest) -> Result<AIResponse, ProviderError>;

    async fn stream(
        &self,
        request: &AIRequest,
    ) -> Result<Box<dyn Stream<Item = Result<String, ProviderError>> + Send + Unpin>, ProviderError>;

    async fn health_check(&self) -> Result<HealthStatus, ProviderError>;

    fn capabilities(&self) -> &ProviderCapability;
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AIRequest {
    pub tenant_id: String,
    pub user_id: String,
    pub prompt: String,
    pub system_prompt: Option<String>,
    pub max_tokens: Option<usize>,
    pub temperature: Option<f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AIResponse {
    pub provider: String,
    pub model: String,
    pub content: String,
    pub usage: TokenUsage,
    pub latency_ms: u64,
    pub cost: f64,
    pub cached: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}

#[derive(Debug, Clone)]
pub enum HealthStatus {
    Healthy,
    Degraded(String),
    Unhealthy(String),
}
```
Routing Algorithm
Selection Process
```rust
// src/ai/router/selector.rs
use std::sync::atomic::{AtomicU32, AtomicU64, AtomicU8, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};

impl AIRouter {
    pub async fn select_provider(
        &self,
        scores: Vec<(String, f64)>,
        request: &AIRequest,
    ) -> Result<RoutingDecision, RouterError> {
        // Get circuit breaker states
        let breaker_states = self.get_circuit_breaker_states().await?;

        // Filter out providers with open circuits
        let available_scores: Vec<_> = scores
            .into_iter()
            .filter(|(provider, _)| {
                breaker_states
                    .get(provider)
                    .map(|state| !state.is_open())
                    .unwrap_or(true)
            })
            .collect();

        if available_scores.is_empty() {
            return Err(RouterError::NoAvailableProviders);
        }

        // Select the top-scoring provider
        let (provider, score) = &available_scores[0];

        // Get provider details
        let capabilities = self.capabilities.read().await;
        let cap = capabilities
            .get(provider)
            .ok_or_else(|| RouterError::ProviderNotFound(provider.clone()))?;

        // Estimate cost from the prompt size plus the completion budget
        let estimated_tokens = estimate_tokens(&request.prompt) + request.max_tokens.unwrap_or(500);
        let estimated_cost = (estimated_tokens as f64 / 1_000_000.0) * cap.cost_per_million_tokens;

        Ok(RoutingDecision {
            provider: provider.clone(),
            reasoning: format!(
                "Selected {} for {:?} task (score: {:.2})",
                provider,
                self.classify_task(request).await?,
                score
            ),
            estimated_cost,
            estimated_latency_ms: cap.average_latency_ms,
            confidence_score: score / 100.0,
        })
    }
}

// Circuit breaker: trips open after `threshold` consecutive failures and
// re-enters half-open once `timeout_ms` has elapsed since it opened.
#[derive(Default)]
pub struct CircuitBreaker {
    failure_count: AtomicU32,
    last_failure: AtomicU64, // unix millis of the failure that opened the circuit
    timeout_ms: AtomicU64,
    state: AtomicU8, // 0 = closed, 1 = open, 2 = half-open
}

impl CircuitBreaker {
    pub fn is_open(&self) -> bool {
        if self.state.load(Ordering::Relaxed) != 1 {
            return false;
        }
        // Transition to half-open once the cool-down window has passed,
        // allowing a probe request through.
        let elapsed = now_millis().saturating_sub(self.last_failure.load(Ordering::Relaxed));
        if elapsed >= self.timeout_ms.load(Ordering::Relaxed) {
            self.state.store(2, Ordering::Relaxed);
            return false;
        }
        true
    }

    pub fn record_success(&self) {
        self.failure_count.store(0, Ordering::Relaxed);
        self.state.store(0, Ordering::Relaxed);
    }

    pub fn record_failure(&self, threshold: u32, timeout_ms: u64) {
        let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        if failures >= threshold {
            self.state.store(1, Ordering::Relaxed);
            self.timeout_ms.store(timeout_ms, Ordering::Relaxed);
            self.last_failure.store(now_millis(), Ordering::Relaxed);
        }
    }
}

fn now_millis() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock before unix epoch")
        .as_millis() as u64
}
```
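The trip-on-threshold behaviour can be exercised in isolation. The sketch below is a stripped-down, self-contained restatement of the same atomics pattern (timeout and half-open recovery omitted), not the production type:

```rust
use std::sync::atomic::{AtomicU32, AtomicU8, Ordering};

// Minimal breaker: counts consecutive failures, opens at the threshold,
// and closes again on the next success.
#[derive(Default)]
struct Breaker {
    failures: AtomicU32,
    state: AtomicU8, // 0 = closed, 1 = open
}

impl Breaker {
    fn is_open(&self) -> bool {
        self.state.load(Ordering::Relaxed) == 1
    }
    fn record_failure(&self, threshold: u32) {
        if self.failures.fetch_add(1, Ordering::Relaxed) + 1 >= threshold {
            self.state.store(1, Ordering::Relaxed);
        }
    }
    fn record_success(&self) {
        self.failures.store(0, Ordering::Relaxed);
        self.state.store(0, Ordering::Relaxed);
    }
}

fn main() {
    let b = Breaker::default();
    for _ in 0..4 {
        b.record_failure(5);
    }
    assert!(!b.is_open()); // four failures: still closed
    b.record_failure(5);
    assert!(b.is_open()); // fifth failure trips the breaker
    b.record_success();
    assert!(!b.is_open()); // a success closes it again
}
```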
Provider Implementations
Claude Provider
```rust
// src/ai/providers/claude.rs
use std::time::Instant;

use anthropic_sdk::{Client, Message, Model};
use async_trait::async_trait;

pub struct ClaudeProvider {
    client: Client,
    capability: ProviderCapability,
}

impl ClaudeProvider {
    pub fn new(api_key: String) -> Self {
        Self {
            client: Client::new(api_key),
            capability: ProviderCapability {
                task_types: vec![
                    TaskType::CodeGeneration,
                    TaskType::CodeReview,
                    TaskType::Documentation,
                    TaskType::DataAnalysis,
                    TaskType::GeneralQuery,
                ],
                max_tokens: 100_000,
                supports_streaming: true,
                supports_functions: true,
                cost_per_million_tokens: 15.0,
                average_latency_ms: 2000,
                success_rate: 0.995,
            },
        }
    }
}

#[async_trait]
impl AIProvider for ClaudeProvider {
    fn name(&self) -> &str {
        "claude"
    }

    async fn complete(&self, request: &AIRequest) -> Result<AIResponse, ProviderError> {
        let start = Instant::now();
        let message = self
            .client
            .messages()
            .model(Model::Claude3Sonnet)
            .max_tokens(request.max_tokens.unwrap_or(4096))
            .temperature(request.temperature.unwrap_or(0.7))
            .system(request.system_prompt.as_deref().unwrap_or(""))
            .messages(&[Message::user(request.prompt.clone())])
            .create()
            .await
            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;

        let latency_ms = start.elapsed().as_millis() as u64;
        let usage = TokenUsage {
            prompt_tokens: message.usage.input_tokens,
            completion_tokens: message.usage.output_tokens,
            total_tokens: message.usage.input_tokens + message.usage.output_tokens,
        };
        let cost =
            (usage.total_tokens as f64 / 1_000_000.0) * self.capability.cost_per_million_tokens;

        Ok(AIResponse {
            provider: self.name().to_string(),
            model: "claude-3-sonnet".to_string(),
            content: message.content[0].text.clone(),
            usage,
            latency_ms,
            cost,
            cached: false,
        })
    }

    // stream() omitted here for brevity.

    async fn health_check(&self) -> Result<HealthStatus, ProviderError> {
        // Cheap liveness probe against the smallest model
        match self
            .client
            .messages()
            .model(Model::Claude3Haiku)
            .max_tokens(10)
            .messages(&[Message::user("Hi".to_string())])
            .create()
            .await
        {
            Ok(_) => Ok(HealthStatus::Healthy),
            Err(e) => Ok(HealthStatus::Unhealthy(e.to_string())),
        }
    }

    fn capabilities(&self) -> &ProviderCapability {
        &self.capability
    }
}
```
Ollama Provider (Local/Free)
```rust
// src/ai/providers/ollama.rs
pub struct OllamaProvider {
    base_url: String,
    client: reqwest::Client,
    capability: ProviderCapability,
}

impl OllamaProvider {
    pub fn new(base_url: String) -> Self {
        Self {
            base_url,
            client: reqwest::Client::new(),
            capability: ProviderCapability {
                task_types: vec![
                    TaskType::GeneralQuery,
                    TaskType::Summarization,
                    TaskType::Translation,
                ],
                max_tokens: 8192,
                supports_streaming: true,
                supports_functions: false,
                cost_per_million_tokens: 0.0, // free / self-hosted
                average_latency_ms: 500,
                success_rate: 0.98,
            },
        }
    }
}
```
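The excerpt above omits the `AIProvider` impl. A `complete()` implementation would POST a JSON body to Ollama's `/api/generate` endpoint; the body builder below is a std-only sketch (a real implementation would use `serde_json`), and the exact field set is an assumption based on Ollama's public HTTP API rather than this codebase:

```rust
// Sketch of the JSON payload OllamaProvider::complete() would POST.
// Naive escaping for illustration only; prefer serde_json in real code.
fn ollama_body(model: &str, prompt: &str) -> String {
    let esc = |s: &str| s.replace('\\', "\\\\").replace('"', "\\\"");
    format!(
        "{{\"model\":\"{}\",\"prompt\":\"{}\",\"stream\":false}}",
        esc(model),
        esc(prompt)
    )
}

fn main() {
    let body = ollama_body("mistral", "Say \"hi\"");
    // Quotes inside the prompt are escaped so the payload stays valid JSON.
    assert!(body.contains("\"model\":\"mistral\""));
    assert!(body.contains("\\\"hi\\\""));
    println!("{body}");
}
```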
Caching Strategy
```rust
// src/ai/router/cache.rs
use redis::{AsyncCommands, Client};
use sha2::{Digest, Sha256};

pub struct ResponseCache {
    redis: Client,
    ttl_seconds: u64,
}

impl ResponseCache {
    pub async fn get(&self, key: &str) -> Result<Option<AIResponse>, CacheError> {
        let mut conn = self.redis.get_async_connection().await?;
        let data: Option<Vec<u8>> = conn.get(key).await?;
        match data {
            Some(bytes) => {
                let response = serde_json::from_slice(&bytes)?;
                Ok(Some(response))
            }
            None => Ok(None),
        }
    }

    pub async fn set(&self, key: &str, response: &AIResponse) -> Result<(), CacheError> {
        let mut conn = self.redis.get_async_connection().await?;
        let data = serde_json::to_vec(response)?;
        conn.set_ex(key, data, self.ttl_seconds as usize).await?;
        Ok(())
    }
}

impl AIRequest {
    pub fn cache_key(&self) -> String {
        let mut hasher = Sha256::new();
        hasher.update(&self.prompt);
        if let Some(system) = &self.system_prompt {
            hasher.update(system);
        }
        hasher.update(self.temperature.unwrap_or(0.7).to_string());
        hasher.update(self.max_tokens.unwrap_or(0).to_string());
        format!("ai:cache:{:x}", hasher.finalize())
    }
}
```
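The cache hinges on `cache_key()` being a pure function of the request: identical requests must collide, and any changed field must miss. That property can be demonstrated with std's `DefaultHasher` standing in for SHA-256 (swapped only to keep the sketch dependency-free; the real key must use a stable cryptographic hash, since `DefaultHasher` is not stable across processes):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Simplified stand-in for AIRequest::cache_key(): deterministic over all
// request fields that affect the response.
fn cache_key(prompt: &str, system: Option<&str>, temperature: f64, max_tokens: usize) -> String {
    let mut h = DefaultHasher::new();
    prompt.hash(&mut h);
    system.hash(&mut h);
    temperature.to_bits().hash(&mut h); // f64 is not Hash; hash its bit pattern
    max_tokens.hash(&mut h);
    format!("ai:cache:{:x}", h.finish())
}

fn main() {
    // Identical inputs map to the same key; any changed field misses.
    assert_eq!(cache_key("Test", None, 0.7, 100), cache_key("Test", None, 0.7, 100));
    assert_ne!(cache_key("Test", None, 0.7, 100), cache_key("Other", None, 0.7, 100));
    assert_ne!(cache_key("Test", None, 0.7, 100), cache_key("Test", None, 0.2, 100));
}
```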
Performance Monitoring
```rust
// src/ai/router/metrics.rs
use chrono::Utc;
use prometheus::{Counter, Histogram};
use uuid::Uuid;

pub struct MetricsCollector {
    db: Database,
    request_counter: Counter,
    latency_histogram: Histogram,
    cost_counter: Counter,
    error_counter: Counter,
}

impl MetricsCollector {
    pub async fn record_request(
        &self,
        tenant_id: &str,
        provider: &str,
        response: &AIResponse,
    ) -> Result<(), MetricsError> {
        // Prometheus metrics
        self.request_counter.inc();
        self.latency_histogram.observe(response.latency_ms as f64);
        self.cost_counter.inc_by(response.cost);

        // Store in FoundationDB for analytics
        let key = format!(
            "{}/ai_usage/{}/{}/{}",
            tenant_id,
            Utc::now().format("%Y-%m-%d"),
            provider,
            Uuid::new_v4()
        );
        let usage_record = UsageRecord {
            timestamp: Utc::now(),
            tenant_id: tenant_id.to_string(),
            provider: provider.to_string(),
            model: response.model.clone(),
            prompt_tokens: response.usage.prompt_tokens,
            completion_tokens: response.usage.completion_tokens,
            cost: response.cost,
            latency_ms: response.latency_ms,
            cached: response.cached,
        };

        let tr = self.db.create_trx()?;
        tr.set(key.as_bytes(), &serde_json::to_vec(&usage_record)?);
        tr.commit().await?;
        Ok(())
    }

    pub async fn get_monthly_usage(
        &self,
        tenant_id: &str,
        provider: &str,
    ) -> Result<f64, MetricsError> {
        // Keys are "{tenant}/ai_usage/{YYYY-MM-DD}/{provider}/{uuid}", so the
        // provider segment comes after the full date. Scan the whole month
        // prefix and filter by provider per record.
        let start_key = format!("{}/ai_usage/{}", tenant_id, Utc::now().format("%Y-%m"));
        // 0xFF sorts after every printable byte, closing the prefix range.
        // (A "\xff" escape is not valid in a Rust string literal, so build bytes.)
        let end_key = [start_key.as_bytes(), &[0xff]].concat();

        let tr = self.db.create_trx()?;
        let range = tr
            .get_range(start_key.as_bytes(), &end_key, RangeOption::default())
            .await?;

        let mut total_cost = 0.0;
        for kv in range {
            let record: UsageRecord = serde_json::from_slice(kv.value())?;
            if record.provider == provider {
                total_cost += record.cost;
            }
        }
        Ok(total_cost)
    }
}
```
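`get_monthly_usage` relies on the lexicographic key layout: all of a tenant's records for a month share the `{tenant}/ai_usage/{YYYY-MM…}` prefix. The same scan can be modelled in memory with a `BTreeMap` in place of FoundationDB; keys follow the shape used above, and the tenants and costs are made up:

```rust
use std::collections::BTreeMap;

// In-memory sketch of the month-prefix range scan over
// "{tenant}/ai_usage/{YYYY-MM-DD}/{provider}/{id}" keys.
fn monthly_usage(store: &BTreeMap<String, f64>, tenant: &str, month: &str, provider: &str) -> f64 {
    let start = format!("{tenant}/ai_usage/{month}");
    let end = format!("{start}\u{10FFFF}"); // just past every key with this prefix
    store
        .range(start..end)
        .filter(|(k, _)| k.contains(&format!("/{provider}/")))
        .map(|(_, cost)| *cost)
        .sum()
}

fn main() {
    let mut store = BTreeMap::new();
    store.insert("acme/ai_usage/2025-08-01/claude/a".to_string(), 1.50);
    store.insert("acme/ai_usage/2025-08-15/ollama/b".to_string(), 0.00);
    store.insert("acme/ai_usage/2025-07-30/claude/c".to_string(), 9.99); // prior month
    let total = monthly_usage(&store, "acme", "2025-08", "claude");
    assert!((total - 1.50).abs() < 1e-9); // prior-month record excluded
    println!("claude August spend: {total}");
}
```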
Testing Requirements
```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_task_classification() {
        let router = create_test_router().await;
        let test_cases = vec![
            ("Write a function to sort an array", TaskType::CodeGeneration),
            ("Review this pull request", TaskType::CodeReview),
            ("Document this API endpoint", TaskType::Documentation),
            ("Analyze sales data for Q3", TaskType::DataAnalysis),
            ("Translate to Spanish", TaskType::Translation),
            ("Summarize this article", TaskType::Summarization),
            ("Write a blog post", TaskType::CreativeWriting),
            ("What is the weather?", TaskType::GeneralQuery),
        ];
        for (prompt, expected) in test_cases {
            let request = AIRequest {
                prompt: prompt.to_string(),
                ..Default::default()
            };
            let task_type = router.classify_task(&request).await.unwrap();
            assert_eq!(task_type, expected);
        }
    }

    #[tokio::test]
    async fn test_provider_selection() {
        let router = create_test_router().await;

        // Code generation should prefer Claude
        let request = AIRequest {
            prompt: "Write a Rust function".to_string(),
            tenant_id: "test".to_string(),
            ..Default::default()
        };
        let decision = router.route_request(&request).await.unwrap();
        assert_eq!(decision.provider, "claude");

        // A simple query should use Ollama (free)
        let request = AIRequest {
            prompt: "What time is it?".to_string(),
            tenant_id: "test".to_string(),
            ..Default::default()
        };
        let decision = router.route_request(&request).await.unwrap();
        assert_eq!(decision.provider, "ollama");
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let breaker = CircuitBreaker::default();

        // Enough failures trip the breaker open
        for _ in 0..5 {
            breaker.record_failure(5, 60000);
        }
        assert!(breaker.is_open());

        // A success closes it again
        breaker.record_success();
        assert!(!breaker.is_open());
    }

    #[tokio::test]
    async fn test_cache_key_generation() {
        let request1 = AIRequest {
            prompt: "Test prompt".to_string(),
            temperature: Some(0.7),
            max_tokens: Some(100),
            ..Default::default()
        };
        let request2 = AIRequest {
            prompt: "Test prompt".to_string(),
            temperature: Some(0.7),
            max_tokens: Some(100),
            ..Default::default()
        };
        let request3 = AIRequest {
            prompt: "Different prompt".to_string(),
            temperature: Some(0.7),
            max_tokens: Some(100),
            ..Default::default()
        };
        assert_eq!(request1.cache_key(), request2.cache_key());
        assert_ne!(request1.cache_key(), request3.cache_key());
    }
}
```
Configuration
config/ai-router.yaml
```yaml
router:
  default_provider: claude
  fallback_providers:
    - openai
    - gemini
    - ollama

cache:
  enabled: true
  ttl_seconds: 3600
  redis_url: "redis://localhost:6379"

circuit_breaker:
  failure_threshold: 5
  timeout_ms: 60000
  half_open_requests: 3

retry:
  max_attempts: 3
  backoff_ms: [100, 500, 1000]

providers:
  claude:
    api_key: "${ANTHROPIC_API_KEY}"
    models:
      - claude-3-opus-20240229
      - claude-3-sonnet-20240229
      - claude-3-haiku-20240307
  openai:
    api_key: "${OPENAI_API_KEY}"
    models:
      - gpt-4-turbo-preview
      - gpt-3.5-turbo
  gemini:
    api_key: "${GOOGLE_API_KEY}"
    models:
      - gemini-pro
      - gemini-pro-vision
  ollama:
    base_url: "http://localhost:11434"
    models:
      - llama2
      - mistral
      - codellama

task_preferences:
  code_generation:
    preferred: [claude, openai]
    min_model_size: large
  code_review:
    preferred: [claude]
    min_model_size: large
  general_query:
    preferred: [ollama, gemini]
    min_model_size: small
  data_analysis:
    preferred: [gemini, claude]
    min_model_size: medium

budget_rules:
  default_monthly_limit: 1000.0
  provider_limits:
    claude: 500.0
    openai: 300.0
    gemini: 200.0
  overage_action: downgrade  # or 'block'
```
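The `${VAR}` placeholders above are expanded from the environment when the config is loaded. A minimal sketch of such an expander, assuming whole-value placeholders only (the helper name `expand_env` is illustrative, not part of the codebase):

```rust
use std::env;

// Expand a "${NAME}" placeholder from the environment; pass other values
// through unchanged. Unset variables expand to an empty string here; a real
// loader would likely return an error instead.
fn expand_env(value: &str) -> String {
    if let Some(name) = value.strip_prefix("${").and_then(|v| v.strip_suffix('}')) {
        env::var(name).unwrap_or_default()
    } else {
        value.to_string()
    }
}

fn main() {
    // Literal values pass through unchanged; unset variables expand to "".
    assert_eq!(expand_env("redis://localhost:6379"), "redis://localhost:6379");
    assert_eq!(expand_env("${CODITECT_UNSET_VAR_FOR_TEST}"), "");
}
```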
Previous Part
Previous: See Part 1: Human Narrative for business context and value proposition.