# scripts-test-multi-model-provider-integration

""" Tests for MultiModelClient and ProviderDetector Integration (ADR-073).

Tests the integration between MultiModelClient and ProviderDetector for provider-aware model selection and routing. """

import os
import unittest
from unittest.mock import patch, MagicMock

from core.multi_model_client import (
    MultiModelClient,
    ModelProvider,
    FallbackConfig,
    create_default_client,
    get_provider_for_model,
    check_api_keys,
    get_client_provider_mode,
    get_model_for_persona_from_client,
)
from core.provider_detector import (
    ProviderMode,
    Provider,
    reset_default_detector,
)

class TestMultiModelClientProviderDetection(unittest.TestCase):
    """Verify that MultiModelClient infers its provider mode from API keys.

    Every scenario patches ``os.environ`` with ``clear=True`` so that only the
    API keys listed in the test are visible. The default detector is reset
    before and after each test so one scenario cannot influence the next.
    """

    def setUp(self):
        """Begin each test with a freshly reset default detector."""
        reset_default_detector()

    def tearDown(self):
        """Reset the default detector again once the test has finished."""
        reset_default_detector()

    def test_client_detects_single_provider(self):
        """A single Anthropic key yields SINGLE mode with one provider."""
        only_anthropic = {"ANTHROPIC_API_KEY": "sk-ant-test"}
        with patch.dict(os.environ, only_anthropic, clear=True):
            client = MultiModelClient()
            self.assertEqual(client.provider_mode, ProviderMode.SINGLE)
            self.assertEqual(client.provider_count, 1)

    def test_client_detects_dual_provider(self):
        """Anthropic plus OpenAI keys yield DUAL mode with two providers."""
        two_keys = {
            "ANTHROPIC_API_KEY": "sk-ant-test",
            "OPENAI_API_KEY": "sk-openai-test",
        }
        with patch.dict(os.environ, two_keys, clear=True):
            client = MultiModelClient()
            self.assertEqual(client.provider_mode, ProviderMode.DUAL)
            self.assertEqual(client.provider_count, 2)

    def test_client_detects_multi_provider(self):
        """Three configured keys yield MULTI mode with three providers."""
        three_keys = {
            "ANTHROPIC_API_KEY": "sk-ant-test",
            "OPENAI_API_KEY": "sk-openai-test",
            "DEEPSEEK_API_KEY": "sk-deepseek-test",
        }
        with patch.dict(os.environ, three_keys, clear=True):
            client = MultiModelClient()
            self.assertEqual(client.provider_mode, ProviderMode.MULTI)
            self.assertEqual(client.provider_count, 3)

    def test_provider_detection_disabled(self):
        """With detection turned off, the client reports MULTI mode."""
        only_anthropic = {"ANTHROPIC_API_KEY": "sk-ant-test"}
        with patch.dict(os.environ, only_anthropic, clear=True):
            client = MultiModelClient(enable_provider_detection=False)
            # Only one key is present, yet with detection disabled the mode
            # must fall back to MULTI rather than SINGLE.
            self.assertEqual(client.provider_mode, ProviderMode.MULTI)

    def test_force_provider_mode(self):
        """An explicit force_provider_mode wins over what the keys imply."""
        only_anthropic = {"ANTHROPIC_API_KEY": "sk-ant-test"}
        with patch.dict(os.environ, only_anthropic, clear=True):
            client = MultiModelClient(force_provider_mode=ProviderMode.MULTI)
            # One key would normally mean SINGLE; the forced mode must stick.
            self.assertEqual(client.provider_mode, ProviderMode.MULTI)

class TestMultiModelClientModelSelection(unittest.TestCase):
    """Check how MultiModelClient picks a model for a persona.

    Covers the default chosen from the detected provider and the two explicit
    overrides exercised here:
    - the per-persona environment variable
      (``CODITECT_JUDGE_MODEL_TECHNICAL_ARCHITECT``);
    - the ``override_model`` argument.
    """

    def setUp(self):
        """Begin each test with a freshly reset default detector."""
        reset_default_detector()

    def tearDown(self):
        """Reset the default detector again once the test has finished."""
        reset_default_detector()

    def test_get_model_for_persona_single_provider(self):
        """Only an Anthropic key configured -> a Claude model is selected."""
        env = {"ANTHROPIC_API_KEY": "sk-ant-test"}
        with patch.dict(os.environ, env, clear=True):
            client = MultiModelClient()
            chosen = client.get_model_for_persona("technical_architect")
            self.assertIn("claude", chosen.lower())

    def test_get_model_for_persona_single_openai(self):
        """Only an OpenAI key configured -> an OpenAI model is selected."""
        env = {"OPENAI_API_KEY": "sk-openai-test"}
        with patch.dict(os.environ, env, clear=True):
            client = MultiModelClient()
            model = client.get_model_for_persona("technical_architect")
            lowered = model.lower()
            # Either a GPT-family or an o3-family name counts as OpenAI.
            self.assertTrue(
                any(marker in lowered for marker in ("gpt", "o3")),
                f"Expected OpenAI model, got {model}",
            )

    def test_env_override_takes_priority(self):
        """A per-persona env variable beats the provider-based default."""
        env = {
            "ANTHROPIC_API_KEY": "sk-ant-test",
            "CODITECT_JUDGE_MODEL_TECHNICAL_ARCHITECT": "custom-model",
        }
        with patch.dict(os.environ, env, clear=True):
            client = MultiModelClient()
            self.assertEqual(
                client.get_model_for_persona("technical_architect"),
                "custom-model",
            )

    def test_override_model_takes_priority(self):
        """An explicit override_model argument is returned as-is."""
        env = {"ANTHROPIC_API_KEY": "sk-ant-test"}
        with patch.dict(os.environ, env, clear=True):
            client = MultiModelClient()
            chosen = client.get_model_for_persona(
                "technical_architect",
                override_model="explicit-model",
            )
            self.assertEqual(chosen, "explicit-model")

class TestMultiModelClientProviderInfo(unittest.TestCase):
    """Test MultiModelClient provider info properties.

    Covers the ``provider_info`` summary dict and re-running detection on an
    existing client through ``refresh_provider_detection``.
    """

    def setUp(self):
        """Begin each test with a freshly reset default detector."""
        reset_default_detector()

    def tearDown(self):
        """Reset the default detector again once the test has finished."""
        reset_default_detector()

    @patch.dict(os.environ, {
        "ANTHROPIC_API_KEY": "sk-ant-test",
        "OPENAI_API_KEY": "sk-openai-test",
    }, clear=True)
    def test_provider_info_property(self):
        """provider_info exposes the mode name and provider count."""
        client = MultiModelClient()

        info = client.provider_info
        self.assertIn("mode", info)
        self.assertIn("provider_count", info)
        self.assertEqual(info["mode"], "dual")
        self.assertEqual(info["provider_count"], 2)

    @patch.dict(os.environ, {
        "ANTHROPIC_API_KEY": "sk-ant-test",
    }, clear=True)
    def test_refresh_provider_detection(self):
        """refresh_provider_detection picks up newly added API keys."""
        client = MultiModelClient()

        # Initially only one provider is configured.
        self.assertEqual(client.provider_mode, ProviderMode.SINGLE)

        # Simulate adding a second provider's key to the environment.
        with patch.dict(os.environ, {
            "ANTHROPIC_API_KEY": "sk-ant-test",
            "OPENAI_API_KEY": "sk-openai-test",
        }, clear=True):
            result = client.refresh_provider_detection()

            # Assert while the two-key environment is still active. Once this
            # `with` exits, os.environ reverts to the single-key setup, and
            # the outcome would then depend on whether the client caches the
            # refreshed result or re-reads the environment.
            self.assertEqual(result.mode, ProviderMode.DUAL)
            self.assertEqual(client.provider_mode, ProviderMode.DUAL)
            # The refresh must update the count as well as the mode.
            self.assertEqual(client.provider_count, 2)

class TestModelProvidersMapping(unittest.TestCase):
    """Test MODEL_PROVIDERS mapping includes latest models.

    Each test checks that a family of model names resolves, via the client's
    ``_get_provider`` lookup, to the expected ``ModelProvider`` member.
    """

    def setUp(self):
        # Provider detection is disabled because only the model-name ->
        # provider lookup is under test here.
        self.client = MultiModelClient(enable_provider_detection=False)

    def _assert_provider(self, expected, *model_names):
        """Assert that every name in ``model_names`` maps to ``expected``."""
        for model_name in model_names:
            self.assertEqual(self.client._get_provider(model_name), expected)

    def test_anthropic_models(self):
        """Claude 4.5 model names map to Anthropic."""
        self._assert_provider(
            ModelProvider.ANTHROPIC,
            "claude-opus-4-5",
            "claude-sonnet-4-5",
            "claude-haiku-4-5",
        )

    def test_openai_models(self):
        """o3 and GPT-4.1 model names map to OpenAI."""
        self._assert_provider(
            ModelProvider.OPENAI,
            "o3",
            "o3-mini",
            "gpt-4.1",
            "gpt-4.1-mini",
        )

    def test_deepseek_models(self):
        """DeepSeek model names map to DeepSeek."""
        self._assert_provider(
            ModelProvider.DEEPSEEK,
            "deepseek-v3.2",
            "deepseek-reasoner",
        )

    def test_meta_models(self):
        """Llama 4 model names map to Meta."""
        self._assert_provider(
            ModelProvider.META,
            "llama-4-maverick",
            "llama-4-scout",
        )

    def test_google_models(self):
        """Gemini model names map to Google."""
        self._assert_provider(
            ModelProvider.GOOGLE,
            "gemini-3-pro",
            "gemini-2.5-flash",
        )

    def test_alibaba_models(self):
        """Qwen model names map to Alibaba."""
        self._assert_provider(
            ModelProvider.ALIBABA,
            "qwen3-72b",
            "qwen2.5-72b",
        )

class TestConvenienceFunctions(unittest.TestCase):
    """Exercise the module-level convenience helpers of multi_model_client."""

    def setUp(self):
        """Begin each test with a freshly reset default detector."""
        reset_default_detector()

    def tearDown(self):
        """Reset the default detector again once the test has finished."""
        reset_default_detector()

    def test_create_default_client(self):
        """create_default_client returns a client that sees both keys."""
        env = {
            "ANTHROPIC_API_KEY": "sk-ant-test",
            "OPENAI_API_KEY": "sk-openai-test",
        }
        with patch.dict(os.environ, env, clear=True):
            client = create_default_client()
            self.assertIsInstance(client, MultiModelClient)
            self.assertEqual(client.provider_mode, ProviderMode.DUAL)

    def test_get_provider_for_model(self):
        """get_provider_for_model returns the provider name as a string."""
        expectations = {
            "claude-sonnet-4": "anthropic",
            "gpt-4.1": "openai",
        }
        for model_name, provider_name in expectations.items():
            self.assertEqual(get_provider_for_model(model_name), provider_name)

    def test_check_api_keys(self):
        """check_api_keys flags exactly the providers whose keys are set."""
        env = {
            "ANTHROPIC_API_KEY": "sk-ant-test",
            "OPENAI_API_KEY": "sk-openai-test",
        }
        with patch.dict(os.environ, env, clear=True):
            availability = check_api_keys()
            self.assertTrue(availability["anthropic"])
            self.assertTrue(availability["openai"])
            # No DEEPSEEK_API_KEY in the cleared environment.
            self.assertFalse(availability["deepseek"])

    def test_get_client_provider_mode(self):
        """get_client_provider_mode reports SINGLE for one configured key."""
        env = {"ANTHROPIC_API_KEY": "sk-ant-test"}
        with patch.dict(os.environ, env, clear=True):
            self.assertEqual(get_client_provider_mode(), ProviderMode.SINGLE)

    def test_get_model_for_persona_from_client(self):
        """The helper selects a Claude model when only Anthropic is set."""
        env = {"ANTHROPIC_API_KEY": "sk-ant-test"}
        with patch.dict(os.environ, env, clear=True):
            selected = get_model_for_persona_from_client("technical_architect")
            self.assertIn("claude", selected.lower())

# Run the suite when this file is executed directly. The dunder names are
# required: a bare `name` is undefined, so the old guard raised NameError
# as soon as the module was imported, including under pytest.
if __name__ == "__main__":
    unittest.main()