scripts-calibration
#!/usr/bin/env python3
"""Confidence Calibration System for MoE Classification."""

import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class CalibrationConfig:
    bins: int = 10
    min_samples: int = 100
    state_file: str = "calibration_state.json"
@dataclass
class CalibrationData:
    confidence: float
    was_correct: bool
    timestamp: datetime = field(default_factory=datetime.now)
class ConfidenceCalibrator:
    """Validates and calibrates confidence scores so that a reported
    90% confidence corresponds to roughly 90% observed accuracy."""
def __init__(self, config: Optional[CalibrationConfig] = None):
self.config = config or CalibrationConfig()
self.data: List[CalibrationData] = []
self._load_state()
def _load_state(self):
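        """Restore previously recorded outcomes from the state file, if any."""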
path = Path(self.config.state_file)
if path.exists():
try:
saved = json.loads(path.read_text())
self.data = [CalibrationData(d['confidence'], d['was_correct'],
datetime.fromisoformat(d['timestamp'])) for d in saved.get('data', [])]
        except Exception as e:
            logger.warning(f"Failed to load calibration state: {e}")
def _save_state(self):
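        """Persist the most recent 5,000 records to disk (best effort)."""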
try:
Path(self.config.state_file).write_text(json.dumps({
'data': [{'confidence': d.confidence, 'was_correct': d.was_correct,
'timestamp': d.timestamp.isoformat()} for d in self.data[-5000:]]
}))
except Exception as e:
logger.warning(f"Failed to save calibration state: {e}")
def record(self, confidence: float, was_correct: bool):
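        """Log one prediction outcome; state is checkpointed every 100 records."""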
self.data.append(CalibrationData(confidence, was_correct))
if len(self.data) % 100 == 0:
self._save_state()
def get_calibration_curve(self) -> Dict:
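        """Return reliability-diagram data: per-bin midpoint confidence vs.
        observed accuracy, plus sample counts and the overall ECE."""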
if len(self.data) < self.config.min_samples:
return {"warning": f"Need {self.config.min_samples} samples, have {len(self.data)}"}
bins = self.config.bins
bin_data = [[] for _ in range(bins)]
for d in self.data:
bin_idx = min(int(d.confidence * bins), bins - 1)
bin_data[bin_idx].append(1 if d.was_correct else 0)
curve = {'predicted': [], 'actual': [], 'counts': []}
for i, samples in enumerate(bin_data):
if samples:
curve['predicted'].append((i + 0.5) / bins)
curve['actual'].append(sum(samples) / len(samples))
curve['counts'].append(len(samples))
curve['ece'] = self._calculate_ece()
return curve
def _calculate_ece(self) -> float:
"""Calculate Expected Calibration Error."""
if len(self.data) < self.config.min_samples:
return 1.0
bins = self.config.bins
bin_data = [[] for _ in range(bins)]
bin_conf = [[] for _ in range(bins)]
for d in self.data:
idx = min(int(d.confidence * bins), bins - 1)
bin_data[idx].append(1 if d.was_correct else 0)
bin_conf[idx].append(d.confidence)
ece = 0.0
for samples, confs in zip(bin_data, bin_conf):
if samples:
avg_acc = sum(samples) / len(samples)
avg_conf = sum(confs) / len(confs)
ece += len(samples) * abs(avg_acc - avg_conf)
return ece / len(self.data)
def calibrate(self, raw_confidence: float) -> float:
"""Apply calibration to raw confidence score."""
if len(self.data) < self.config.min_samples:
return raw_confidence
# Simple histogram-based calibration
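        # Replace the raw score with the empirical accuracy of its confidence
        # bucket: if past predictions in this bucket were correct 72% of the
        # time, return 0.72 regardless of where the raw score fell within it.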
bins = self.config.bins
bin_idx = min(int(raw_confidence * bins), bins - 1)
bin_samples = [(d.confidence, d.was_correct) for d in self.data
if min(int(d.confidence * bins), bins - 1) == bin_idx]
if len(bin_samples) < 10:
return raw_confidence
actual_accuracy = sum(1 for _, correct in bin_samples if correct) / len(bin_samples)
return actual_accuracy
def get_stats(self) -> Dict:
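        """Summarize calibration health: sample count, ECE, and the curve
        (None until min_samples outcomes have been recorded)."""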
curve = self.get_calibration_curve()
return {
'sample_count': len(self.data),
'ece': curve.get('ece', None),
'is_calibrated': len(self.data) >= self.config.min_samples,
'calibration_curve': curve if 'warning' not in curve else None
}
_calibrator: Optional[ConfidenceCalibrator] = None
def get_calibrator(config: Optional[CalibrationConfig] = None) -> ConfidenceCalibrator:
    global _calibrator
    if _calibrator is None:
        _calibrator = ConfidenceCalibrator(config)
    return _calibrator
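

# Illustrative usage sketch, not part of the module's API: the state-file
# path and the simulated classifier below are hypothetical, chosen only to
# demonstrate the record() -> calibrate() -> get_stats() flow.
if __name__ == "__main__":
    import random

    logging.basicConfig(level=logging.INFO)
    cal = get_calibrator(CalibrationConfig(state_file="/tmp/calibration_demo.json"))
    for _ in range(500):
        conf = random.random()
        # Simulate a slightly overconfident classifier: it is correct a bit
        # less often than its reported confidence suggests.
        cal.record(conf, random.random() < conf * 0.9)
    print(f"calibrated 0.90 -> {cal.calibrate(0.90):.3f}")
    print(f"ece = {cal.get_stats()['ece']:.3f}")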