""" MoE Orchestrator for Document Classification.
Coordinates the full classification pipeline:
- Run all analysts in parallel
- Calculate consensus from votes
- Run judges for validation
- Produce final classification with audit trail
Per ADR-019: Multi-agent orchestration with parallel execution. Per ADR-073: Provider-aware model selection and confidence adjustment. """
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from .models import (
    Document,
    AnalystVote,
    JudgeDecision,
    ConsensusResult,
    ClassificationResult,
    ApprovalType,
)
from .consensus import ConsensusCalculator, ConsensusConfig
from .deep_analysts import DeepAnalysisOrchestrator, DeepAnalysisResult
from .provider_detector import (
    ProviderDetector,
    ProviderMode,
    Provider,
    ProviderDetectionResult,
    get_default_detector,
    reset_default_detector,
)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class OrchestratorConfig:
    """Configuration for the MoE orchestrator."""

    # Parallel execution
    max_parallel_analysts: int = 5
    max_parallel_judges: int = 3
    analyst_timeout_seconds: float = 30.0
    judge_timeout_seconds: float = 10.0
# Retry settings
max_retries: int = 2
retry_delay_seconds: float = 0.5
# Consensus settings
consensus_config: ConsensusConfig = field(default_factory=ConsensusConfig)
# Behavior
skip_judges_on_auto_approve: bool = False # Still run judges for audit
fail_on_analyst_error: bool = False # Continue with remaining votes
# Deep analysis for escalations
enable_deep_analysis: bool = True # Run deep analysts on escalated docs
deep_analysis_confidence_threshold: float = 0.65 # Threshold for deep analysis consensus
# Provider detection settings (ADR-073)
enable_provider_detection: bool = True # Auto-detect available providers
apply_confidence_adjustment: bool = True # Apply mode-based confidence adjustment
force_provider_mode: Optional[ProviderMode] = None # Override detected mode
log_provider_info: bool = True # Include provider info in stats/results
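
# A minimal configuration sketch (illustrative values; ``my_analysts`` and
# ``my_judges`` are placeholder lists of BaseAnalyst/BaseJudge instances):
#
#   config = OrchestratorConfig(
#       max_parallel_analysts=8,
#       analyst_timeout_seconds=60.0,
#       enable_deep_analysis=False,
#   )
#   orchestrator = MoEOrchestrator(analysts=my_analysts, judges=my_judges,
#                                  config=config)
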
@dataclass
class OrchestratorStats:
    """Statistics from orchestration run."""

    documents_processed: int = 0
    auto_approved: int = 0
    judge_approved: int = 0
    escalated: int = 0
    deep_analysis_resolved: int = 0  # Escalations resolved by deep analysis
    human_review_required: int = 0  # Escalations still needing human review
    errors: int = 0
    total_time_ms: int = 0
    avg_analyst_time_ms: float = 0.0
    avg_judge_time_ms: float = 0.0
    avg_deep_analysis_time_ms: float = 0.0
# Provider detection stats (ADR-073)
provider_mode: str = "unknown"
provider_count: int = 0
available_providers: List[str] = field(default_factory=list)
confidence_adjustment: float = 0.0
diversity_strategy: str = "unknown"
total_confidence_adjustments_applied: int = 0
avg_raw_confidence: float = 0.0 # Before adjustment
avg_adjusted_confidence: float = 0.0 # After adjustment

class MoEOrchestrator:
    """Mixture of Experts Orchestrator for document classification.

    Coordinates analysts and judges to produce high-confidence
    classifications with full audit trails.

    Per ADR-073: Supports provider-aware model selection and
    confidence adjustment based on available LLM providers.
    """
def __init__(
self,
analysts: List, # List of BaseAnalyst instances
judges: List, # List of BaseJudge instances
config: Optional[OrchestratorConfig] = None
):
self.analysts = analysts
self.judges = judges
self.config = config or OrchestratorConfig()
self.consensus_calc = ConsensusCalculator(self.config.consensus_config)
self.stats = OrchestratorStats()
# Initialize deep analysis orchestrator for escalations
if self.config.enable_deep_analysis:
self.deep_orchestrator = DeepAnalysisOrchestrator(
confidence_threshold=self.config.deep_analysis_confidence_threshold
)
else:
self.deep_orchestrator = None
# Initialize provider detection (ADR-073)
self._provider_detection_result: Optional[ProviderDetectionResult] = None
if self.config.enable_provider_detection:
self._detect_providers()
def _detect_providers(self) -> None:
"""Detect available LLM providers and update stats (ADR-073)."""
try:
detector = get_default_detector()
self._provider_detection_result = detector.detect_mode(
force_mode=self.config.force_provider_mode
)
# Update stats with provider info
result = self._provider_detection_result
self.stats.provider_mode = result.mode.value
self.stats.provider_count = result.provider_count
self.stats.available_providers = [p.value for p in result.available_providers]
self.stats.confidence_adjustment = result.confidence_adjustment
self.stats.diversity_strategy = result.diversity_strategy
if self.config.log_provider_info:
logger.info(
f"Provider detection: mode={result.mode.value}, "
f"providers={self.stats.available_providers}, "
f"confidence_adjustment={result.confidence_adjustment:+.0%}"
)
except Exception as e:
logger.warning(f"Provider detection failed, using defaults: {e}")
self._provider_detection_result = None
self.stats.provider_mode = "unknown"
def _apply_confidence_adjustment(self, confidence: float) -> float:
"""Apply provider mode confidence adjustment (ADR-073).
Args:
confidence: Raw confidence score (0.0-1.0)
Returns:
Adjusted confidence score
"""
if not self.config.apply_confidence_adjustment:
return confidence
if self._provider_detection_result is None:
return confidence
adjustment = self._provider_detection_result.confidence_adjustment
adjusted = max(0.0, min(1.0, confidence + adjustment))
# Track stats
self.stats.total_confidence_adjustments_applied += 1
n = self.stats.total_confidence_adjustments_applied
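        # Incremental mean: avg_n = (avg_{n-1} * (n - 1) + x_n) / n, so the
        # running averages are maintained without storing per-call history.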
self.stats.avg_raw_confidence = (
(self.stats.avg_raw_confidence * (n - 1) + confidence) / n
)
self.stats.avg_adjusted_confidence = (
(self.stats.avg_adjusted_confidence * (n - 1) + adjusted) / n
)
return adjusted
@property
def provider_mode(self) -> ProviderMode:
"""Get the detected provider mode."""
if self._provider_detection_result:
return self._provider_detection_result.mode
return ProviderMode.MULTI # Default to full diversity
@property
def provider_info(self) -> Dict[str, Any]:
"""Get provider detection information."""
if self._provider_detection_result:
return self._provider_detection_result.to_dict()
return {"mode": "unknown", "provider_count": 0}
    def refresh_provider_detection(self) -> Optional[ProviderDetectionResult]:
"""Refresh provider detection (e.g., after environment changes).
Returns:
            Updated ProviderDetectionResult, or None if detection failed
"""
reset_default_detector()
self._detect_providers()
return self._provider_detection_result
def classify(self, document: Document) -> ClassificationResult:
"""Classify a single document through the full MoE pipeline.
Args:
document: Document to classify
Returns:
ClassificationResult with classification and audit trail
"""
start_time = time.time()
# Phase 1: Run all analysts
votes = self._run_analysts(document)
if not votes:
return self._create_error_result(
document,
"No analyst votes obtained",
start_time
)
# Phase 2: Calculate consensus
consensus = self.consensus_calc.calculate_from_votes(votes)
# Apply provider mode confidence adjustment (ADR-073)
if self.config.apply_confidence_adjustment and consensus.confidence:
raw_conf = consensus.confidence
adjusted_confidence = self._apply_confidence_adjustment(raw_conf)
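            # Rebuild the ConsensusResult rather than mutating it in place,
            # so both the raw and adjusted confidence stay on the record.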
consensus = ConsensusResult(
classification=consensus.classification,
confidence=adjusted_confidence,
agreement_ratio=consensus.agreement_ratio,
approval_type=consensus.approval_type,
votes=consensus.votes,
judge_decisions=consensus.judge_decisions,
escalation_reason=consensus.escalation_reason,
deep_analysis_reasoning=getattr(consensus, 'deep_analysis_reasoning', None),
# Provider adjustment fields (ADR-073)
provider_adjustment_applied=True,
raw_confidence=raw_conf,
provider_mode=self.stats.provider_mode
)
# Phase 3: Run judges (unless auto-approved and configured to skip)
if not (consensus.approval_type == ApprovalType.AUTO_APPROVED and
self.config.skip_judges_on_auto_approve):
decisions = self._run_judges(document, votes)
consensus = self.consensus_calc.apply_judge_decisions(consensus, decisions)
# Phase 4: Deep analysis for escalated documents
deep_analysis_result = None
if (consensus.approval_type == ApprovalType.ESCALATED and
self.deep_orchestrator is not None):
deep_start = time.time()
# Create a preliminary result for deep analysis
preliminary_result = ClassificationResult(
document_path=str(document.path),
result=consensus,
timestamp=datetime.utcnow(),
processing_time_ms=0
)
deep_analysis_result = self.deep_orchestrator.analyze_escalation(
document, preliminary_result
)
# Update consensus with deep analysis results
if deep_analysis_result.consensus_reached:
consensus = ConsensusResult(
classification=deep_analysis_result.final_classification,
confidence=deep_analysis_result.final_confidence,
agreement_ratio=consensus.agreement_ratio,
approval_type=ApprovalType.DEEP_ANALYSIS_APPROVED,
votes=consensus.votes,
judge_decisions=consensus.judge_decisions,
escalation_reason=None,
deep_analysis_reasoning=deep_analysis_result.reasoning
)
self.stats.deep_analysis_resolved += 1
else:
# Deep analysis couldn't resolve, still needs human review
consensus = ConsensusResult(
classification=deep_analysis_result.final_classification,
confidence=deep_analysis_result.final_confidence,
agreement_ratio=consensus.agreement_ratio,
approval_type=ApprovalType.HUMAN_REVIEW_REQUIRED,
votes=consensus.votes,
judge_decisions=consensus.judge_decisions,
escalation_reason=f"Deep analysis inconclusive: {deep_analysis_result.reasoning}",
deep_analysis_reasoning=deep_analysis_result.reasoning
)
self.stats.human_review_required += 1
# Track deep analysis time
deep_time = (time.time() - deep_start) * 1000
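            # Running average over every deep-analysis run, whether it was
            # resolved or handed on to human review.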
self.stats.avg_deep_analysis_time_ms = (
(self.stats.avg_deep_analysis_time_ms *
(self.stats.deep_analysis_resolved + self.stats.human_review_required - 1) +
deep_time) /
(self.stats.deep_analysis_resolved + self.stats.human_review_required)
if (self.stats.deep_analysis_resolved + self.stats.human_review_required) > 0
else deep_time
)
# Update stats
self._update_stats(consensus, start_time)
return ClassificationResult(
document_path=str(document.path),
result=consensus,
timestamp=datetime.utcnow(),
processing_time_ms=int((time.time() - start_time) * 1000),
deep_analysis=deep_analysis_result
)
def classify_batch(
self,
documents: List[Document],
progress_callback: Optional[Callable[[int, int], None]] = None
) -> List[ClassificationResult]:
"""Classify multiple documents.
Args:
documents: List of documents to classify
progress_callback: Optional callback(current, total) for progress
Returns:
List of ClassificationResult objects
"""
results = []
for i, doc in enumerate(documents):
try:
result = self.classify(doc)
results.append(result)
except Exception as e:
logger.error(f"Error classifying {doc.path}: {e}")
results.append(self._create_error_result(
doc,
f"Classification error: {str(e)}",
time.time()
))
self.stats.errors += 1
if progress_callback:
progress_callback(i + 1, len(documents))
return results
def _run_analysts(self, document: Document) -> List[AnalystVote]:
"""Run all analysts in parallel."""
votes = []
analyst_times = []
with ThreadPoolExecutor(max_workers=self.config.max_parallel_analysts) as executor:
future_to_analyst = {
executor.submit(self._run_single_analyst, analyst, document): analyst
for analyst in self.analysts
}
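            # Note: this timeout bounds the whole batch; if it expires,
            # as_completed raises concurrent.futures.TimeoutError, which
            # propagates to the caller (classify_batch records it as an error).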
for future in as_completed(future_to_analyst, timeout=self.config.analyst_timeout_seconds):
analyst = future_to_analyst[future]
try:
vote = future.result()
if vote:
votes.append(vote)
analyst_times.append(vote.duration_ms)
except Exception as e:
logger.warning(f"Analyst {analyst.name} failed: {e}")
if self.config.fail_on_analyst_error:
raise
# Update avg analyst time
if analyst_times:
self.stats.avg_analyst_time_ms = sum(analyst_times) / len(analyst_times)
return votes
def _run_single_analyst(self, analyst, document: Document) -> Optional[AnalystVote]:
"""Run a single analyst with retry logic."""
for attempt in range(self.config.max_retries + 1):
try:
return analyst.analyze(document)
except Exception as e:
if attempt < self.config.max_retries:
logger.debug(f"Analyst {analyst.name} retry {attempt + 1}: {e}")
time.sleep(self.config.retry_delay_seconds)
else:
logger.warning(f"Analyst {analyst.name} failed after {self.config.max_retries + 1} attempts: {e}")
return None
return None
def _run_judges(
self,
document: Document,
votes: List[AnalystVote]
) -> List[JudgeDecision]:
"""Run all judges in parallel."""
decisions = []
judge_times = []
with ThreadPoolExecutor(max_workers=self.config.max_parallel_judges) as executor:
future_to_judge = {
executor.submit(self._run_single_judge, judge, document, votes): judge
for judge in self.judges
}
for future in as_completed(future_to_judge, timeout=self.config.judge_timeout_seconds):
judge = future_to_judge[future]
try:
decision = future.result()
if decision:
decisions.append(decision)
judge_times.append(decision.metadata.get('duration_ms', 0))
except Exception as e:
logger.warning(f"Judge {judge.name} failed: {e}")
# Create a rejection decision for failed judge
decisions.append(JudgeDecision(
judge=judge.name,
approved=False,
reason=f"Judge error: {str(e)}",
confidence=0.0,
metadata={'error': str(e)}
))
# Update avg judge time
if judge_times:
self.stats.avg_judge_time_ms = sum(judge_times) / len(judge_times)
return decisions
def _run_single_judge(
self,
judge,
document: Document,
votes: List[AnalystVote]
) -> Optional[JudgeDecision]:
"""Run a single judge."""
try:
return judge.evaluate(document, votes)
except Exception as e:
logger.warning(f"Judge {judge.name} error: {e}")
return None
def _create_error_result(
self,
document: Document,
error_message: str,
start_time: float
) -> ClassificationResult:
"""Create an error result for failed classification."""
return ClassificationResult(
document_path=str(document.path),
result=ConsensusResult(
classification=None,
confidence=0.0,
agreement_ratio=0.0,
approval_type=ApprovalType.ESCALATED,
escalation_reason=error_message
),
timestamp=datetime.utcnow(),
processing_time_ms=int((time.time() - start_time) * 1000)
)
def _update_stats(self, consensus: ConsensusResult, start_time: float):
"""Update orchestrator statistics."""
self.stats.documents_processed += 1
self.stats.total_time_ms += int((time.time() - start_time) * 1000)
if consensus.approval_type == ApprovalType.AUTO_APPROVED:
self.stats.auto_approved += 1
elif consensus.approval_type == ApprovalType.JUDGE_APPROVED:
self.stats.judge_approved += 1
elif consensus.approval_type == ApprovalType.ESCALATED:
self.stats.escalated += 1
def get_stats(self) -> Dict:
"""Get orchestrator statistics."""
total_approved = (
self.stats.auto_approved +
self.stats.judge_approved +
self.stats.deep_analysis_resolved
)
return {
'documents_processed': self.stats.documents_processed,
'auto_approved': self.stats.auto_approved,
'judge_approved': self.stats.judge_approved,
'deep_analysis_resolved': self.stats.deep_analysis_resolved,
'human_review_required': self.stats.human_review_required,
'escalated': self.stats.escalated,
'errors': self.stats.errors,
'total_time_ms': self.stats.total_time_ms,
'avg_time_per_doc_ms': (
self.stats.total_time_ms / self.stats.documents_processed
if self.stats.documents_processed > 0 else 0
),
'avg_analyst_time_ms': round(self.stats.avg_analyst_time_ms, 2),
'avg_judge_time_ms': round(self.stats.avg_judge_time_ms, 2),
'avg_deep_analysis_time_ms': round(self.stats.avg_deep_analysis_time_ms, 2),
'approval_rate': (
total_approved / self.stats.documents_processed
if self.stats.documents_processed > 0 else 0
),
'escalation_resolution_rate': (
self.stats.deep_analysis_resolved /
(self.stats.deep_analysis_resolved + self.stats.human_review_required)
if (self.stats.deep_analysis_resolved + self.stats.human_review_required) > 0
else 0
),
# Provider detection stats (ADR-073)
'provider_mode': self.stats.provider_mode,
'provider_count': self.stats.provider_count,
'available_providers': self.stats.available_providers,
'diversity_strategy': self.stats.diversity_strategy,
'confidence_adjustment': self.stats.confidence_adjustment,
'total_confidence_adjustments': self.stats.total_confidence_adjustments_applied,
'avg_raw_confidence': round(self.stats.avg_raw_confidence, 4),
'avg_adjusted_confidence': round(self.stats.avg_adjusted_confidence, 4),
}
def reset_stats(self):
"""Reset statistics."""
self.stats = OrchestratorStats()

def create_default_orchestrator() -> MoEOrchestrator:
    """Create an orchestrator with default analysts and judges.
Returns:
Configured MoEOrchestrator instance
"""
# Import here to avoid circular imports
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from analysts import get_all_analysts
from judges import get_all_judges
analysts = get_all_analysts()
judges = get_all_judges()
return MoEOrchestrator(analysts=analysts, judges=judges)
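

if __name__ == "__main__":
    # Minimal usage sketch (run via ``python -m <package>.orchestrator``,
    # since this module uses relative imports). The Document constructor
    # below is an assumption; the real signature lives in .models.
    orchestrator = create_default_orchestrator()
    doc = Document(path=Path("sample_document.txt"))  # hypothetical args
    result = orchestrator.classify(doc)
    print(result.result.approval_type, result.result.confidence)
    print(orchestrator.get_stats())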