# scripts-test-multi-model-client
""" Tests for Multi-Model Client (H.3.5.6).
Tests cover:
- ModelProvider enum and model routing
- CompletionRequest/CompletionResponse dataclasses
- FallbackConfig configuration
- MultiModelClient initialization and provider detection
- Mock completion flow with retry and fallback logic
- API key checking and available providers """
import asyncio
import json
import os
import sys
import tempfile
import unittest
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

# Put the parent of this file's directory at the front of sys.path.
# NOTE(review): presumably this lets the ``moe_classifier`` package import
# that follows resolve when the tests are run from a checkout - confirm
# against the repository layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from moe_classifier.core.multi_model_client import ( MultiModelClient, ModelProvider, CompletionRequest, CompletionResponse, FallbackConfig, ModelConfig, BaseProviderClient, AnthropicClient, OpenAIClient, DeepSeekClient, TogetherClient, GoogleClient, DashScopeClient, create_default_client, get_provider_for_model, check_api_keys, )
class TestModelProvider(unittest.TestCase):
    """Checks on the members of the ModelProvider enum."""

    # (member name, expected string value) for every provider under test.
    EXPECTED_VALUES = (
        ("ANTHROPIC", "anthropic"),
        ("OPENAI", "openai"),
        ("DEEPSEEK", "deepseek"),
        ("ALIBABA", "alibaba"),
        ("META", "meta"),
        ("GOOGLE", "google"),
    )

    def test_provider_values(self):
        """Each provider member carries its expected string value."""
        for member_name, expected in self.EXPECTED_VALUES:
            member = getattr(ModelProvider, member_name)
            self.assertEqual(member.value, expected)

    def test_provider_count(self):
        """The enum defines exactly six providers."""
        member_total = len(ModelProvider)
        self.assertEqual(member_total, 6)

    def test_provider_string_enum(self):
        """Every member's value is a ``str``."""
        for member in ModelProvider:
            raw_value = member.value
            self.assertIsInstance(raw_value, str)
class TestCompletionRequest(unittest.TestCase):
    """Checks on construction of the CompletionRequest dataclass."""

    def test_basic_request(self):
        """Required fields are stored and optional fields take their defaults."""
        supplied = {
            "prompt": "Test prompt",
            "model": "claude-sonnet-4",
            "persona_id": "technical_architect",
        }
        request = CompletionRequest(**supplied)

        # The explicitly supplied fields come back unchanged.
        for field_name, expected in supplied.items():
            self.assertEqual(getattr(request, field_name), expected)

        # Everything not supplied falls back to its default.
        self.assertEqual(request.max_tokens, 4096)
        self.assertEqual(request.temperature, 0.0)
        self.assertIsNone(request.system_prompt)

    def test_request_with_options(self):
        """Optional fields passed explicitly override the defaults."""
        overrides = {
            "max_tokens": 2048,
            "temperature": 0.5,
            "system_prompt": "You are a security expert.",
        }
        request = CompletionRequest(
            prompt="Test",
            model="gpt-4o",
            persona_id="security_auditor",
            metadata={"key": "value"},
            **overrides,
        )

        for field_name, expected in overrides.items():
            self.assertEqual(getattr(request, field_name), expected)
        self.assertEqual(request.metadata, {"key": "value"})
class TestCompletionResponse(unittest.TestCase):
    """Checks on the CompletionResponse dataclass and its ``to_dict`` output."""

    @staticmethod
    def _build(**fields):
        """Create a CompletionResponse, stamping it with the current UTC time
        when the caller did not pass ``timestamp``."""
        if "timestamp" not in fields:
            fields["timestamp"] = datetime.now(timezone.utc)
        return CompletionResponse(**fields)

    def test_successful_response(self):
        """A response built without ``success``/``error`` reports success."""
        response = self._build(
            content="The evaluation is complete.",
            model_used="claude-sonnet-4",
            provider="anthropic",
            token_usage=150,
            input_tokens=50,
            output_tokens=100,
            latency_ms=1200,
        )

        self.assertTrue(response.success)
        self.assertIsNone(response.error)
        self.assertEqual(response.token_usage, 150)

    def test_failed_response(self):
        """An explicit failure keeps its flag and error message."""
        response = self._build(
            content="",
            model_used="gpt-4o",
            provider="openai",
            token_usage=0,
            input_tokens=0,
            output_tokens=0,
            latency_ms=500,
            success=False,
            error="API rate limit exceeded",
        )

        self.assertFalse(response.success)
        self.assertEqual(response.error, "API rate limit exceeded")

    def test_to_dict(self):
        """``to_dict`` exposes the fields and an ISO-formatted timestamp."""
        fixed_time = datetime(2026, 1, 8, 12, 0, 0, tzinfo=timezone.utc)
        serialized = self._build(
            content="Short response",
            model_used="claude-sonnet-4",
            provider="anthropic",
            token_usage=100,
            input_tokens=40,
            output_tokens=60,
            latency_ms=800,
            timestamp=fixed_time,
        ).to_dict()

        expected_entries = (
            ("content", "Short response"),
            ("model_used", "claude-sonnet-4"),
            ("provider", "anthropic"),
            ("timestamp", "2026-01-08T12:00:00+00:00"),
        )
        for key, expected in expected_entries:
            self.assertEqual(serialized[key], expected)

    def test_to_dict_truncates_long_content(self):
        """Content of 1000 characters is cut to 500 plus an ellipsis."""
        serialized = self._build(
            content="x" * 1000,
            model_used="claude-sonnet-4",
            provider="anthropic",
            token_usage=100,
            input_tokens=40,
            output_tokens=60,
            latency_ms=800,
        ).to_dict()

        shortened = serialized["content"]
        self.assertTrue(shortened.endswith("..."))
        # 500 kept characters followed by the three-character "..." marker.
        self.assertEqual(len(shortened), 500 + len("..."))
class TestFallbackConfig(unittest.TestCase):
    """Checks on FallbackConfig defaults and explicit overrides."""

    def test_default_config(self):
        """A bare FallbackConfig exposes the expected default settings."""
        defaults = FallbackConfig()

        self.assertEqual(defaults.max_retries, 2)
        self.assertEqual(defaults.retry_delay_seconds, 1.0)
        for backup_flag in (
            defaults.use_backup_on_failure,
            defaults.use_backup_on_timeout,
        ):
            self.assertTrue(backup_flag)
        self.assertEqual(defaults.timeout_seconds, 60)

    def test_custom_config(self):
        """Values passed to the constructor are stored as given."""
        settings = {
            "max_retries": 5,
            "retry_delay_seconds": 0.5,
            "use_backup_on_failure": False,
            "use_backup_on_timeout": False,
            "timeout_seconds": 120,
        }
        custom = FallbackConfig(**settings)

        self.assertEqual(custom.max_retries, 5)
        self.assertEqual(custom.retry_delay_seconds, 0.5)
        for backup_flag in (
            custom.use_backup_on_failure,
            custom.use_backup_on_timeout,
        ):
            self.assertFalse(backup_flag)
        self.assertEqual(custom.timeout_seconds, 120)
class TestMultiModelClientInit(unittest.TestCase):
    """Checks on constructing a MultiModelClient."""

    def test_default_init(self):
        """A client built with no arguments has a fallback config and a config path."""
        built = MultiModelClient()
        for attribute in (built.fallback_config, built.config_path):
            self.assertIsNotNone(attribute)

    def test_custom_fallback_config(self):
        """A FallbackConfig passed to the constructor is the one the client uses."""
        built = MultiModelClient(fallback_config=FallbackConfig(max_retries=5))
        retries = built.fallback_config.max_retries
        self.assertEqual(retries, 5)
class TestModelProviderDetection(unittest.TestCase):
    """Checks that model names resolve to the right ModelProvider."""

    def setUp(self):
        self.client = MultiModelClient()

    def _assert_routes_to(self, provider, *model_names):
        """Assert, in order, that each model name resolves to ``provider``."""
        for model_name in model_names:
            resolved = self.client._get_provider(model_name)
            self.assertEqual(resolved, provider)

    def test_anthropic_models(self):
        """Claude models resolve to Anthropic."""
        self._assert_routes_to(
            ModelProvider.ANTHROPIC,
            "claude-opus-4-5",
            "claude-sonnet-4",
            "claude-haiku-4-5",
        )

    def test_openai_models(self):
        """GPT models resolve to OpenAI."""
        self._assert_routes_to(
            ModelProvider.OPENAI,
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
        )

    def test_deepseek_models(self):
        """DeepSeek models resolve to DeepSeek."""
        self._assert_routes_to(
            ModelProvider.DEEPSEEK,
            "deepseek-v3",
            "deepseek-chat",
        )

    def test_alibaba_models(self):
        """Qwen models resolve to Alibaba."""
        self._assert_routes_to(
            ModelProvider.ALIBABA,
            "qwen2.5-72b",
            "qwen-max",
        )

    def test_meta_models(self):
        """Llama models resolve to Meta."""
        self._assert_routes_to(
            ModelProvider.META,
            "llama-3.3-70b",
            "llama-3.1-405b",
        )

    def test_google_models(self):
        """Gemini models resolve to Google."""
        self._assert_routes_to(
            ModelProvider.GOOGLE,
            "gemini-2.0-flash",
            "gemini-1.5-pro",
        )

    def test_pattern_matching_fallback(self):
        """Model names not used elsewhere in these tests still resolve by name pattern."""
        by_pattern = (
            ("claude-unknown-model", ModelProvider.ANTHROPIC),
            ("gpt-5-future", ModelProvider.OPENAI),
            ("deepseek-future", ModelProvider.DEEPSEEK),
            ("qwen3-100b", ModelProvider.ALIBABA),
            ("llama-4", ModelProvider.META),
            ("gemini-3", ModelProvider.GOOGLE),
        )
        for model_name, provider in by_pattern:
            self._assert_routes_to(provider, model_name)

    def test_unknown_model_raises(self):
        """A name matching no provider raises ValueError."""
        self.assertRaises(ValueError, self.client._get_provider, "unknown-model")
class TestModelInfo(unittest.TestCase):
    """Checks on ``MultiModelClient.get_model_info``."""

    def test_get_model_info_known(self):
        """A known model reports its name, provider and API-key variable."""
        details = MultiModelClient().get_model_info("claude-sonnet-4")

        expected_entries = (
            ("model", "claude-sonnet-4"),
            ("provider", "anthropic"),
            ("api_key_env", "ANTHROPIC_API_KEY"),
        )
        for key, expected in expected_entries:
            self.assertEqual(details[key], expected)

    def test_get_model_info_unknown(self):
        """An unknown model yields a result carrying an ``error`` key."""
        details = MultiModelClient().get_model_info("unknown-xyz")
        self.assertIn("error", details)
class TestBackupModel(unittest.TestCase):
    """Checks on looking up a persona's backup model from the routing config."""

    def test_backup_model_from_config(self):
        """The backup model named in the routing config is returned."""
        routing = {
            "technical_architect": {
                "primary_model": "claude-sonnet-4",
                "backup_model": "gpt-4o",
            }
        }
        # Write the config to a throwaway JSON file the client can load.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
            json.dump({"routing": routing}, handle)
            config_path = Path(handle.name)

        try:
            loaded = MultiModelClient(config_path=config_path)
            self.assertEqual(loaded.get_backup_model("technical_architect"), "gpt-4o")
        finally:
            # delete=False means the file is removed by hand.
            config_path.unlink()

    def test_backup_model_missing_persona(self):
        """A persona absent from the routing table has no backup model."""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
            json.dump({"routing": {}}, handle)
            config_path = Path(handle.name)

        try:
            loaded = MultiModelClient(config_path=config_path)
            self.assertIsNone(loaded.get_backup_model("unknown_persona"))
        finally:
            config_path.unlink()
class TestAPIKeyChecking(unittest.TestCase):
    """Tests for API key checking."""

    def test_check_api_keys(self):
        """check_api_keys returns a boolean entry for every provider."""
        result = check_api_keys()

        for provider in ("anthropic", "openai", "deepseek", "alibaba", "meta", "google"):
            self.assertIn(provider, result)

        # Values should be booleans
        for value in result.values():
            self.assertIsInstance(value, bool)

    @patch.dict(os.environ, {
        "ANTHROPIC_API_KEY": "test-key",
        "OPENAI_API_KEY": "test-key"
    }, clear=True)
    def test_available_providers_with_keys(self):
        """Only providers whose API key is set are reported as available."""
        # clear=True leaves os.environ holding only the two keys above for
        # the duration of this test, so no other provider key needs to be
        # removed by hand.
        client = MultiModelClient()
        available = client.get_available_providers()

        self.assertIn("anthropic", available)
        self.assertIn("openai", available)
        self.assertNotIn("deepseek", available)
class TestGetProviderForModel(unittest.TestCase):
    """Checks on the module-level ``get_provider_for_model`` helper."""

    def test_get_provider_claude(self):
        """A Claude model name maps to the ``anthropic`` provider string."""
        provider_name = get_provider_for_model("claude-sonnet-4")
        self.assertEqual(provider_name, "anthropic")

    def test_get_provider_gpt(self):
        """A GPT model name maps to the ``openai`` provider string."""
        provider_name = get_provider_for_model("gpt-4o")
        self.assertEqual(provider_name, "openai")
class TestConvenienceFunctions(unittest.TestCase):
    """Checks on the module-level convenience helpers."""

    def test_create_default_client(self):
        """``create_default_client`` hands back a MultiModelClient instance."""
        self.assertIsInstance(create_default_client(), MultiModelClient)
class TestClientCreation(unittest.TestCase):
    """Checks that ``_get_client`` builds the right provider client class."""

    @staticmethod
    def _with_key(env_var):
        """Return a patcher that sets ``env_var`` to a dummy key in os.environ."""
        return patch.dict(os.environ, {env_var: "test-key"})

    def test_get_anthropic_client(self):
        """Anthropic provider yields an AnthropicClient tagged with that provider."""
        with self._with_key("ANTHROPIC_API_KEY"):
            built = MultiModelClient()._get_client(ModelProvider.ANTHROPIC)
            self.assertIsInstance(built, AnthropicClient)
            self.assertEqual(built.provider, ModelProvider.ANTHROPIC)

    def test_get_openai_client(self):
        """OpenAI provider yields an OpenAIClient."""
        with self._with_key("OPENAI_API_KEY"):
            built = MultiModelClient()._get_client(ModelProvider.OPENAI)
            self.assertIsInstance(built, OpenAIClient)

    def test_get_deepseek_client(self):
        """DeepSeek provider yields a DeepSeekClient."""
        with self._with_key("DEEPSEEK_API_KEY"):
            built = MultiModelClient()._get_client(ModelProvider.DEEPSEEK)
            self.assertIsInstance(built, DeepSeekClient)

    def test_get_alibaba_client(self):
        """Alibaba provider yields a DashScopeClient."""
        with self._with_key("DASHSCOPE_API_KEY"):
            built = MultiModelClient()._get_client(ModelProvider.ALIBABA)
            self.assertIsInstance(built, DashScopeClient)

    def test_get_meta_client(self):
        """Meta provider yields a TogetherClient."""
        with self._with_key("TOGETHER_API_KEY"):
            built = MultiModelClient()._get_client(ModelProvider.META)
            self.assertIsInstance(built, TogetherClient)

    def test_get_google_client(self):
        """Google provider yields a GoogleClient."""
        with self._with_key("GOOGLE_API_KEY"):
            built = MultiModelClient()._get_client(ModelProvider.GOOGLE)
            self.assertIsInstance(built, GoogleClient)

    def test_missing_api_key_raises(self):
        """With an empty environment, the ValueError names the missing variable."""
        with patch.dict(os.environ, {}, clear=True):
            unconfigured = MultiModelClient()
            with self.assertRaisesRegex(ValueError, "ANTHROPIC_API_KEY"):
                unconfigured._get_client(ModelProvider.ANTHROPIC)
class TestCompletionFlow(unittest.TestCase):
    """Completion scenarios driven through mocked provider clients."""

    def setUp(self):
        self.client = MultiModelClient()

    @staticmethod
    def _stamped(**fields):
        """Build a CompletionResponse timestamped with the current UTC time."""
        return CompletionResponse(timestamp=datetime.now(timezone.utc), **fields)

    @patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"})
    def test_successful_completion(self):
        """A successful provider reply is returned to the caller."""

        async def scenario():
            canned = self._stamped(
                content="Evaluation complete. PASS.",
                model_used="claude-sonnet-4",
                provider="anthropic",
                token_usage=100,
                input_tokens=50,
                output_tokens=50,
                latency_ms=1000,
                success=True,
            )

            with patch.object(AnthropicClient, 'complete', return_value=canned):
                outcome = await self.client.get_completion(
                    model="claude-sonnet-4",
                    prompt="Evaluate this code",
                    persona_id="technical_architect",
                )

            self.assertTrue(outcome.success)
            self.assertEqual(outcome.model_used, "claude-sonnet-4")

        asyncio.run(scenario())

    @patch.dict(os.environ, {
        "ANTHROPIC_API_KEY": "test-key",
        "OPENAI_API_KEY": "backup-key"
    })
    def test_fallback_on_failure(self):
        """A failing primary model hands over to the persona's backup model.

        The routing config names gpt-4o as the backup for the persona and
        sets max_retries to 0; the test expects exactly one call to each
        provider and the backup's reply as the final result.
        """

        async def scenario():
            settings = {
                "routing": {
                    "technical_architect": {
                        "primary_model": "claude-sonnet-4",
                        "backup_model": "gpt-4o",
                    }
                },
                "fallback_strategy": {
                    "max_retries": 0,  # Immediate fallback
                    "use_backup_on_failure": True,
                },
            }
            with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
                json.dump(settings, handle)
                config_path = Path(handle.name)

            try:
                # Reply from the primary model: a failure.
                primary_reply = self._stamped(
                    content="",
                    model_used="claude-sonnet-4",
                    provider="anthropic",
                    token_usage=0,
                    input_tokens=0,
                    output_tokens=0,
                    latency_ms=100,
                    success=False,
                    error="Rate limit exceeded",
                )
                # Reply from the backup model: a success.
                backup_reply = self._stamped(
                    content="Backup evaluation complete.",
                    model_used="gpt-4o",
                    provider="openai",
                    token_usage=100,
                    input_tokens=50,
                    output_tokens=50,
                    latency_ms=800,
                    success=True,
                )

                calls = {ModelProvider.ANTHROPIC: 0, ModelProvider.OPENAI: 0}

                def stub_client(client_cls, provider, reply):
                    """Mock provider client that counts calls and returns ``reply``."""
                    stub = MagicMock(spec=client_cls)
                    stub.provider = provider

                    async def complete(request):
                        calls[provider] += 1
                        return reply

                    stub.complete = complete
                    return stub

                client = MultiModelClient(config_path=config_path)
                # Seed the client cache so no real provider client is built.
                client._clients[ModelProvider.ANTHROPIC] = stub_client(
                    AnthropicClient, ModelProvider.ANTHROPIC, primary_reply
                )
                client._clients[ModelProvider.OPENAI] = stub_client(
                    OpenAIClient, ModelProvider.OPENAI, backup_reply
                )

                outcome = await client.get_completion(
                    model="claude-sonnet-4",
                    prompt="Evaluate this",
                    persona_id="technical_architect",
                )

                self.assertTrue(outcome.success)
                self.assertEqual(outcome.model_used, "gpt-4o")
                self.assertEqual(calls[ModelProvider.ANTHROPIC], 1)
                self.assertEqual(calls[ModelProvider.OPENAI], 1)
            finally:
                config_path.unlink()

        asyncio.run(scenario())

    @patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"})
    def test_retry_logic(self):
        """Two failed attempts followed by a success take three calls in total."""

        async def scenario():
            attempts = 0

            async def flaky_complete(request):
                nonlocal attempts
                attempts += 1
                # Succeed from the third attempt on; fail before that.
                if attempts >= 3:
                    return self._stamped(
                        content="Success after retry",
                        model_used=request.model,
                        provider="anthropic",
                        token_usage=100,
                        input_tokens=50,
                        output_tokens=50,
                        latency_ms=500,
                        success=True,
                    )
                return self._stamped(
                    content="",
                    model_used=request.model,
                    provider="anthropic",
                    token_usage=0,
                    input_tokens=0,
                    output_tokens=0,
                    latency_ms=100,
                    success=False,
                    error="Transient error",
                )

            # A tiny retry delay keeps the test fast.
            retry_settings = FallbackConfig(max_retries=3, retry_delay_seconds=0.01)
            client = MultiModelClient(fallback_config=retry_settings)

            with patch.object(AnthropicClient, 'complete', side_effect=flaky_complete):
                outcome = await client.get_completion(
                    model="claude-sonnet-4",
                    prompt="Test retry",
                    persona_id="test",
                    use_fallback=False,
                )

            self.assertTrue(outcome.success)
            self.assertEqual(attempts, 3)

        asyncio.run(scenario())
class TestFallbackConfigLoading(unittest.TestCase):
    """Checks that the client picks up fallback settings from its config file."""

    def test_load_fallback_from_config_file(self):
        """The ``fallback_strategy`` section of the JSON file populates fallback_config."""
        strategy = {
            "max_retries": 5,
            "retry_delay_seconds": 2.0,
            "use_backup_on_failure": False,
            "timeout_seconds": 120,
        }
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
            json.dump({"fallback_strategy": strategy}, handle)
            config_path = Path(handle.name)

        try:
            loaded = MultiModelClient(config_path=config_path).fallback_config
            self.assertEqual(loaded.max_retries, 5)
            self.assertEqual(loaded.retry_delay_seconds, 2.0)
            self.assertFalse(loaded.use_backup_on_failure)
            self.assertEqual(loaded.timeout_seconds, 120)
        finally:
            # delete=False means the file is removed by hand.
            config_path.unlink()
class TestProviderClients(unittest.TestCase):
    """Checks on the individual provider client classes."""

    # Dummy key handed to every client constructor.
    API_KEY = "test-key"

    def test_anthropic_client_provider(self):
        """AnthropicClient reports the Anthropic provider."""
        reported = AnthropicClient(self.API_KEY).provider
        self.assertEqual(reported, ModelProvider.ANTHROPIC)

    def test_openai_client_provider(self):
        """OpenAIClient reports the OpenAI provider."""
        reported = OpenAIClient(self.API_KEY).provider
        self.assertEqual(reported, ModelProvider.OPENAI)

    def test_deepseek_client_provider(self):
        """DeepSeekClient reports the DeepSeek provider and its base URL."""
        deepseek = DeepSeekClient(self.API_KEY)
        self.assertEqual(deepseek.provider, ModelProvider.DEEPSEEK)
        self.assertEqual(deepseek.BASE_URL, "https://api.deepseek.com")

    def test_together_client_provider(self):
        """TogetherClient reports the Meta provider and its base URL."""
        together = TogetherClient(self.API_KEY)
        self.assertEqual(together.provider, ModelProvider.META)
        self.assertEqual(together.BASE_URL, "https://api.together.xyz/v1")

    def test_google_client_provider(self):
        """GoogleClient reports the Google provider."""
        reported = GoogleClient(self.API_KEY).provider
        self.assertEqual(reported, ModelProvider.GOOGLE)

    def test_dashscope_client_provider(self):
        """DashScopeClient reports the Alibaba provider."""
        reported = DashScopeClient(self.API_KEY).provider
        self.assertEqual(reported, ModelProvider.ALIBABA)
class TestModelConfig(unittest.TestCase):
    """Checks on ModelConfig defaults and explicit overrides."""

    def test_model_config_defaults(self):
        """Only the required fields are given; the rest take their defaults."""
        minimal = ModelConfig(
            model_id="claude-sonnet-4",
            provider=ModelProvider.ANTHROPIC,
            api_key_env="ANTHROPIC_API_KEY",
        )

        observed = (minimal.max_tokens, minimal.temperature, minimal.timeout_seconds)
        for actual, expected in zip(observed, (4096, 0.0, 60)):
            self.assertEqual(actual, expected)

    def test_model_config_custom(self):
        """Optional fields passed explicitly are stored as given."""
        tuning = {"max_tokens": 8192, "temperature": 0.5, "timeout_seconds": 120}
        tuned = ModelConfig(
            model_id="gpt-4o",
            provider=ModelProvider.OPENAI,
            api_key_env="OPENAI_API_KEY",
            **tuning,
        )

        for field_name, expected in tuning.items():
            self.assertEqual(getattr(tuned, field_name), expected)
# Run the test suite only when this file is executed as a script; importing
# the module (for example during test discovery) must not start a run.
if __name__ == '__main__':
    unittest.main()