Skip to main content

scripts-adaptive-thresholds

#!/usr/bin/env python3
"""Adaptive Threshold System for MoE Classification."""

import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

# Module-level logger named after this module, so log records can be filtered
# and routed per module by the application's logging configuration.
logger = logging.getLogger(__name__)

@dataclass
class ThresholdConfig:
    """Static configuration for the adaptive threshold manager.

    The ``base_*`` values seed a fresh ``ThresholdState`` and act as fallbacks
    for keys missing from a persisted state file.
    """

    # Starting auto-approval threshold.
    base_auto_approval: float = 0.90
    # Starting judge-approval threshold.
    base_judge_approval: float = 0.85
    # Starting agreement threshold (not adjusted anywhere in this module).
    base_agreement: float = 0.60
    # Minimum number of recorded classifications before any adjustment runs.
    min_samples: int = 50
    # Desired fraction of classifications that get escalated; adjustments
    # trigger when the observed rate leaves the 0.5x-1.5x band around it.
    target_escalation_rate: float = 0.10
    # Path of the JSON file used to persist state between runs.
    state_file: str = "threshold_state.json"

@dataclass
class ThresholdState:
    """Mutable runtime state: current thresholds plus running counters.

    The three thresholds are required; counters default to zero and
    ``history`` to a fresh empty list per instance.
    """

    # Current thresholds.
    auto_approval: float
    judge_approval: float
    agreement: float
    # Number of classifications recorded.
    total: int = 0
    # Number of recorded classifications that were escalated.
    escalated: int = 0
    # Number of confirmations recorded as correct.
    correct: int = 0
    # Number of confirmations recorded.
    confirmed: int = 0
    # Log of adjustment snapshots (dicts appended by the manager).
    history: List[Dict] = field(default_factory=list)

    @property
    def escalation_rate(self) -> float:
        """Return ``escalated / total``, or 0.0 when nothing was recorded."""
        return self.escalated / self.total if self.total else 0.0

    @property
    def accuracy_rate(self) -> float:
        """Return ``correct / confirmed``, or 0.0 when nothing was confirmed."""
        return self.correct / self.confirmed if self.confirmed else 0.0

class AdaptiveThresholdManager:
    """Track classification outcomes and adapt approval thresholds.

    State is loaded from ``config.state_file`` on construction and written
    back after every adjustment pass. Adjustment passes run every 100
    recorded classifications and every 50 recorded confirmations.
    """

    def __init__(self, config: Optional["ThresholdConfig"] = None):
        """Create a manager.

        Args:
            config: Configuration to use; a default ``ThresholdConfig`` is
                created when omitted.
        """
        self.config = config or ThresholdConfig()
        self.state = self._load_state()

    def _load_state(self) -> "ThresholdState":
        """Load persisted state, falling back to the configured base values.

        Missing keys in the state file fall back to the config's base
        thresholds or to zero counters. Any failure to read or parse the file
        is logged and results in a fresh default state.
        """
        path = Path(self.config.state_file)
        if path.exists():
            try:
                data = json.loads(path.read_text())
                return ThresholdState(
                    auto_approval=data.get('auto_approval', self.config.base_auto_approval),
                    judge_approval=data.get('judge_approval', self.config.base_judge_approval),
                    agreement=data.get('agreement', self.config.base_agreement),
                    total=data.get('total', 0),
                    escalated=data.get('escalated', 0),
                    correct=data.get('correct', 0),
                    confirmed=data.get('confirmed', 0),
                    history=data.get('history', []),
                )
            except Exception as e:
                # Loading is best-effort: a corrupt or unreadable file must not
                # prevent start-up, but the failure should be visible.
                logger.warning(f"Failed to load state from {path}: {e}")
        return ThresholdState(
            self.config.base_auto_approval,
            self.config.base_judge_approval,
            self.config.base_agreement,
        )

    def _save_state(self):
        """Write the current state to ``config.state_file`` as JSON.

        Only the 50 most recent history entries are persisted. Failures are
        logged as warnings and otherwise ignored.
        """
        try:
            Path(self.config.state_file).write_text(json.dumps({
                'auto_approval': self.state.auto_approval,
                'judge_approval': self.state.judge_approval,
                'agreement': self.state.agreement,
                'total': self.state.total,
                'escalated': self.state.escalated,
                'correct': self.state.correct,
                'confirmed': self.state.confirmed,
                'history': self.state.history[-50:],
            }, indent=2))
        except Exception as e:
            logger.warning(f"Failed to save state: {e}")

    def record_classification(self, confidence: float, was_escalated: bool, approval_type: str):
        """Record one classification and adjust on every 100th call.

        Args:
            confidence: Accepted for interface compatibility; not used here.
            was_escalated: Whether this classification was escalated.
            approval_type: Accepted for interface compatibility; not used here.
        """
        self.state.total += 1
        if was_escalated:
            self.state.escalated += 1
        if self.state.total % 100 == 0:
            self._adjust()

    def record_confirmation(self, was_correct: bool):
        """Record one confirmation and adjust on every 50th call.

        Args:
            was_correct: Whether the confirmed classification was correct.
        """
        self.state.confirmed += 1
        if was_correct:
            self.state.correct += 1
        if self.state.confirmed % 50 == 0:
            self._adjust()

    def _adjust(self):
        """Run one adjustment pass, then append a history entry and persist.

        Does nothing until at least ``config.min_samples`` classifications
        have been recorded.

        * Escalation rate above 1.5x the target lowers the auto and judge
          thresholds (floors 0.80 and 0.70).
        * Escalation rate below 0.5x the target combined with accuracy below
          0.90 raises them (ceilings 0.95 and 0.92).
        """
        if self.state.total < self.config.min_samples:
            return
        esc_rate = self.state.escalation_rate
        # With no confirmations yet, assume 0.85 accuracy rather than the
        # 0.0 that accuracy_rate would report.
        acc_rate = self.state.accuracy_rate if self.state.confirmed else 0.85

        if esc_rate > self.config.target_escalation_rate * 1.5:
            # Step size scales with how far the rate overshoots the target.
            adj = -0.02 * (esc_rate - self.config.target_escalation_rate)
            self.state.auto_approval = max(0.80, self.state.auto_approval + adj)
            self.state.judge_approval = max(0.70, self.state.judge_approval + adj)
            logger.info(f"Lowered thresholds due to high escalation ({esc_rate:.1%})")
        elif esc_rate < self.config.target_escalation_rate * 0.5 and acc_rate < 0.90:
            # Step size scales with the accuracy shortfall below 0.90.
            adj = 0.02 * (0.90 - acc_rate)
            self.state.auto_approval = min(0.95, self.state.auto_approval + adj)
            self.state.judge_approval = min(0.92, self.state.judge_approval + adj)
            logger.info(f"Raised thresholds due to low accuracy ({acc_rate:.1%})")

        self.state.history.append({'time': datetime.now().isoformat(), 'esc': esc_rate, 'acc': acc_rate})
        self._save_state()

    def get_thresholds(self) -> Dict[str, float]:
        """Return the current auto-approval, judge-approval and agreement thresholds."""
        return {
            'auto_approval': self.state.auto_approval,
            'judge_approval': self.state.judge_approval,
            'agreement': self.state.agreement,
        }

    def get_auto_approval_threshold(self) -> float:
        """Return the current auto-approval threshold."""
        return self.state.auto_approval

    def get_judge_approval_threshold(self) -> float:
        """Return the current judge-approval threshold."""
        return self.state.judge_approval

    def get_stats(self) -> Dict:
        """Return current thresholds, total count, escalation rate and accuracy."""
        return {
            'thresholds': self.get_thresholds(),
            'total': self.state.total,
            'escalation_rate': self.state.escalation_rate,
            'accuracy': self.state.accuracy_rate,
        }

# Process-wide manager instance, created lazily by get_threshold_manager().
_manager: Optional[AdaptiveThresholdManager] = None


def get_threshold_manager(config: Optional[ThresholdConfig] = None) -> AdaptiveThresholdManager:
    """Return the shared ``AdaptiveThresholdManager``, creating it on first use.

    Args:
        config: Configuration passed to the manager when it is first created.
            It is ignored on later calls, which return the existing instance.

    Returns:
        The module-level manager instance.
    """
    global _manager
    if _manager is None:
        _manager = AdaptiveThresholdManager(config)
    return _manager