#!/usr/bin/env python3
"""Classification Learning System for MoE.

Implements a feedback loop that learns from classification outcomes:
- Tracks analyst accuracy over time
- Dynamically adjusts analyst weights based on performance
- Records classification history for pattern analysis
- Enables confirmation/correction of classifications

Features:
- SQLite-backed persistent storage
- Analyst performance tracking
- Dynamic weight calculation
- Bootstrap from existing classifications
"""

import json
import logging
import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Fixed: original called logging.getLogger(name), which is a NameError;
# the conventional module-level logger uses __name__.
logger = logging.getLogger(__name__)
@dataclass
class LearningConfig:
    """Configuration for the learning system."""

    # Path of the SQLite database file backing the learner.
    db_path: str = "moe_learning.db"
    # Minimum confirmed samples before an analyst's weight is adjusted.
    min_samples_for_weight: int = 10
    # Balance between measured accuracy and the default weight.
    weight_smoothing: float = 0.5
    # Slight decay for old samples.
    decay_factor: float = 0.99
    # Keep 90 days of history.
    max_history_days: int = 90
@dataclass
class AnalystPerformance:
    """Performance metrics for an analyst."""

    analyst: str
    correct_count: int
    total_count: int
    accuracy: float
    dynamic_weight: float
    # NOTE(review): callers may construct this with None when the DB row has
    # no timestamp (see get_analyst_performance), despite the annotation.
    last_updated: datetime
@dataclass
class ClassificationOutcome:
    """Record of a classification and its outcome."""

    id: int
    document_path: str
    predicted_type: str
    # None until the classification is confirmed or corrected.
    actual_type: Optional[str]
    confidence: float
    # Decoded JSON list of per-analyst vote dicts.
    analyst_votes: List[Dict]
    created_at: datetime
    # None while the classification is still awaiting confirmation.
    confirmed_at: Optional[datetime]
class ClassificationLearner:
    """Learns from historical classification outcomes.

    Tracks analyst accuracy and adjusts weights dynamically.
    """

    def __init__(self, config: Optional[LearningConfig] = None):
        """Create a learner and ensure its SQLite schema exists.

        Args:
            config: Optional configuration; defaults to ``LearningConfig()``.
        """
        self.config = config or LearningConfig()
        self._init_database()
@contextmanager
def _get_connection(self):
"""Get database connection with context manager."""
conn = sqlite3.connect(self.config.db_path)
conn.row_factory = sqlite3.Row
try:
yield conn
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _init_database(self):
"""Initialize learning database tables."""
with self._get_connection() as conn:
# Classification outcomes table
conn.execute("""
CREATE TABLE IF NOT EXISTS classification_outcomes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
document_path TEXT NOT NULL,
document_hash TEXT,
predicted_type TEXT NOT NULL,
actual_type TEXT,
confidence REAL NOT NULL,
agreement_ratio REAL,
analyst_votes TEXT NOT NULL,
judge_decisions TEXT,
approval_type TEXT,
confirmed_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(document_path, created_at)
)
""")
# Analyst accuracy tracking
conn.execute("""
CREATE TABLE IF NOT EXISTS analyst_accuracy (
analyst TEXT PRIMARY KEY,
correct_count INTEGER DEFAULT 0,
total_count INTEGER DEFAULT 0,
accuracy REAL DEFAULT 0.0,
dynamic_weight REAL DEFAULT 1.0,
last_correct_at TIMESTAMP,
last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Type-specific accuracy (analyst performance per document type)
conn.execute("""
CREATE TABLE IF NOT EXISTS analyst_type_accuracy (
analyst TEXT NOT NULL,
document_type TEXT NOT NULL,
correct_count INTEGER DEFAULT 0,
total_count INTEGER DEFAULT 0,
accuracy REAL DEFAULT 0.0,
last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (analyst, document_type)
)
""")
# Learning events log
conn.execute("""
CREATE TABLE IF NOT EXISTS learning_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
event_type TEXT NOT NULL,
details TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Create indexes
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_outcomes_path
ON classification_outcomes(document_path)
""")
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_outcomes_predicted
ON classification_outcomes(predicted_type)
""")
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_outcomes_confirmed
ON classification_outcomes(confirmed_at)
""")
logger.info(f"Learning database initialized at {self.config.db_path}")
def record_classification(self,
document_path: str,
predicted_type: str,
confidence: float,
agreement_ratio: float,
analyst_votes: List[Dict],
judge_decisions: Optional[List[Dict]] = None,
approval_type: Optional[str] = None,
document_hash: Optional[str] = None):
"""
Record a classification for future learning.
Args:
document_path: Path to classified document
predicted_type: The classification result
confidence: Classification confidence
agreement_ratio: Analyst agreement ratio
analyst_votes: List of analyst vote dictionaries
judge_decisions: Optional list of judge decisions
approval_type: Type of approval (AUTO, JUDGE, etc.)
document_hash: Optional content hash for deduplication
"""
with self._get_connection() as conn:
try:
conn.execute("""
INSERT INTO classification_outcomes
(document_path, document_hash, predicted_type, confidence,
agreement_ratio, analyst_votes, judge_decisions, approval_type)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (
str(document_path),
document_hash,
predicted_type,
confidence,
agreement_ratio,
json.dumps(analyst_votes),
json.dumps(judge_decisions) if judge_decisions else None,
approval_type
))
# Log event
conn.execute("""
INSERT INTO learning_events (event_type, details)
VALUES ('classification_recorded', ?)
""", (json.dumps({
"path": str(document_path),
"type": predicted_type,
"confidence": confidence
}),))
logger.debug(f"Recorded classification: {document_path} -> {predicted_type}")
except sqlite3.IntegrityError:
logger.debug(f"Classification already recorded for {document_path}")
def confirm_classification(self,
document_path: str,
actual_type: str,
outcome_id: Optional[int] = None) -> bool:
"""
Confirm or correct a classification - triggers learning.
Args:
document_path: Path to the document
actual_type: The actual/correct type (may differ from predicted)
outcome_id: Optional specific outcome ID to confirm
Returns:
True if confirmation was recorded, False otherwise
"""
with self._get_connection() as conn:
# Find the most recent unconfirmed classification
if outcome_id:
query = """
SELECT id, predicted_type, analyst_votes FROM classification_outcomes
WHERE id = ? AND actual_type IS NULL
"""
params = (outcome_id,)
else:
query = """
SELECT id, predicted_type, analyst_votes FROM classification_outcomes
WHERE document_path = ? AND actual_type IS NULL
ORDER BY created_at DESC LIMIT 1
"""
params = (str(document_path),)
cursor = conn.execute(query, params)
row = cursor.fetchone()
if not row:
logger.warning(f"No unconfirmed classification found for {document_path}")
return False
outcome_id, predicted_type, votes_json = row['id'], row['predicted_type'], row['analyst_votes']
votes = json.loads(votes_json)
was_correct = predicted_type == actual_type
# Update outcome with confirmation
conn.execute("""
UPDATE classification_outcomes
SET actual_type = ?, confirmed_at = ?
WHERE id = ?
""", (actual_type, datetime.now(timezone.utc).isoformat(), outcome_id))
# Update analyst accuracy for each vote
for vote in votes:
analyst = vote.get('agent') or vote.get('analyst')
vote_correct = vote.get('classification') == actual_type
# Update overall accuracy
conn.execute("""
INSERT INTO analyst_accuracy (analyst, correct_count, total_count, last_updated)
VALUES (?, ?, 1, ?)
ON CONFLICT(analyst) DO UPDATE SET
correct_count = correct_count + ?,
total_count = total_count + 1,
accuracy = CAST(correct_count + ? AS REAL) / (total_count + 1),
last_correct_at = CASE WHEN ? = 1 THEN ? ELSE last_correct_at END,
last_updated = ?
""", (
analyst,
1 if vote_correct else 0,
datetime.now(timezone.utc).isoformat(),
1 if vote_correct else 0,
1 if vote_correct else 0,
1 if vote_correct else 0,
datetime.now(timezone.utc).isoformat(),
datetime.now(timezone.utc).isoformat()
))
# Update type-specific accuracy
conn.execute("""
INSERT INTO analyst_type_accuracy
(analyst, document_type, correct_count, total_count, last_updated)
VALUES (?, ?, ?, 1, ?)
ON CONFLICT(analyst, document_type) DO UPDATE SET
correct_count = correct_count + ?,
total_count = total_count + 1,
accuracy = CAST(correct_count + ? AS REAL) / (total_count + 1),
last_updated = ?
""", (
analyst,
actual_type,
1 if vote_correct else 0,
datetime.now(timezone.utc).isoformat(),
1 if vote_correct else 0,
1 if vote_correct else 0,
datetime.now(timezone.utc).isoformat()
))
# Recalculate dynamic weights
self._update_dynamic_weights(conn)
# Log learning event
conn.execute("""
INSERT INTO learning_events (event_type, details)
VALUES ('classification_confirmed', ?)
""", (json.dumps({
"path": str(document_path),
"predicted": predicted_type,
"actual": actual_type,
"correct": was_correct
}),))
logger.info(f"Confirmed classification: {document_path} "
f"(predicted={predicted_type}, actual={actual_type}, correct={was_correct})")
return True
def _update_dynamic_weights(self, conn: sqlite3.Connection):
"""Update dynamic weights based on accuracy."""
cursor = conn.execute("""
SELECT analyst, accuracy, total_count FROM analyst_accuracy
WHERE total_count >= ?
""", (self.config.min_samples_for_weight,))
for row in cursor.fetchall():
analyst, accuracy, total_count = row['analyst'], row['accuracy'], row['total_count']
# Calculate confidence factor based on sample size
confidence_factor = min(1.0, total_count / 100)
# Dynamic weight: blend accuracy with default (1.0)
dynamic_weight = (
accuracy * confidence_factor +
(1 - confidence_factor) * self.config.weight_smoothing
)
conn.execute("""
UPDATE analyst_accuracy
SET dynamic_weight = ?
WHERE analyst = ?
""", (dynamic_weight, analyst))
def get_analyst_weights(self) -> Dict[str, float]:
"""
Get dynamic weights based on analyst accuracy.
Returns:
Dict mapping analyst names to their weights
"""
weights = {}
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT analyst, dynamic_weight, total_count FROM analyst_accuracy
WHERE total_count >= ?
""", (self.config.min_samples_for_weight,))
for row in cursor.fetchall():
weights[row['analyst']] = row['dynamic_weight']
return weights
def get_analyst_performance(self, analyst: Optional[str] = None) -> List[AnalystPerformance]:
"""Get performance metrics for analysts."""
performances = []
with self._get_connection() as conn:
if analyst:
query = "SELECT * FROM analyst_accuracy WHERE analyst = ?"
params = (analyst,)
else:
query = "SELECT * FROM analyst_accuracy ORDER BY accuracy DESC"
params = ()
cursor = conn.execute(query, params)
for row in cursor.fetchall():
performances.append(AnalystPerformance(
analyst=row['analyst'],
correct_count=row['correct_count'],
total_count=row['total_count'],
accuracy=row['accuracy'],
dynamic_weight=row['dynamic_weight'],
last_updated=datetime.fromisoformat(row['last_updated']) if row['last_updated'] else None
))
return performances
def get_type_accuracy(self, analyst: str) -> Dict[str, float]:
"""Get analyst accuracy by document type."""
type_accuracy = {}
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT document_type, accuracy, total_count
FROM analyst_type_accuracy
WHERE analyst = ? AND total_count >= 5
""", (analyst,))
for row in cursor.fetchall():
type_accuracy[row['document_type']] = row['accuracy']
return type_accuracy
def find_similar_by_path(self, document_path: str, limit: int = 5) -> List[Dict]:
"""Find similar documents by path pattern."""
path = Path(document_path)
parent = str(path.parent)
suffix = path.suffix
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT document_path, predicted_type, actual_type, confidence
FROM classification_outcomes
WHERE (document_path LIKE ? OR document_path LIKE ?)
AND actual_type IS NOT NULL
ORDER BY created_at DESC
LIMIT ?
""", (f"%{parent}%", f"%{suffix}", limit))
return [dict(row) for row in cursor.fetchall()]
def get_classification_history(self,
document_path: Optional[str] = None,
limit: int = 100) -> List[ClassificationOutcome]:
"""Get classification history."""
outcomes = []
with self._get_connection() as conn:
if document_path:
query = """
SELECT * FROM classification_outcomes
WHERE document_path = ?
ORDER BY created_at DESC
LIMIT ?
"""
params = (str(document_path), limit)
else:
query = """
SELECT * FROM classification_outcomes
ORDER BY created_at DESC
LIMIT ?
"""
params = (limit,)
cursor = conn.execute(query, params)
for row in cursor.fetchall():
outcomes.append(ClassificationOutcome(
id=row['id'],
document_path=row['document_path'],
predicted_type=row['predicted_type'],
actual_type=row['actual_type'],
confidence=row['confidence'],
analyst_votes=json.loads(row['analyst_votes']),
created_at=datetime.fromisoformat(row['created_at']),
confirmed_at=datetime.fromisoformat(row['confirmed_at']) if row['confirmed_at'] else None
))
return outcomes
def get_pending_confirmations(self, limit: int = 50) -> List[ClassificationOutcome]:
"""Get classifications awaiting confirmation."""
outcomes = []
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT * FROM classification_outcomes
WHERE actual_type IS NULL
ORDER BY created_at DESC
LIMIT ?
""", (limit,))
for row in cursor.fetchall():
outcomes.append(ClassificationOutcome(
id=row['id'],
document_path=row['document_path'],
predicted_type=row['predicted_type'],
actual_type=None,
confidence=row['confidence'],
analyst_votes=json.loads(row['analyst_votes']),
created_at=datetime.fromisoformat(row['created_at']),
confirmed_at=None
))
return outcomes
def get_stats(self) -> Dict:
"""Get learning system statistics."""
with self._get_connection() as conn:
total_classifications = conn.execute(
"SELECT COUNT(*) FROM classification_outcomes"
).fetchone()[0]
confirmed_count = conn.execute(
"SELECT COUNT(*) FROM classification_outcomes WHERE actual_type IS NOT NULL"
).fetchone()[0]
correct_count = conn.execute(
"SELECT COUNT(*) FROM classification_outcomes WHERE predicted_type = actual_type"
).fetchone()[0]
analyst_count = conn.execute(
"SELECT COUNT(*) FROM analyst_accuracy"
).fetchone()[0]
avg_accuracy = conn.execute(
"SELECT AVG(accuracy) FROM analyst_accuracy WHERE total_count >= ?"
, (self.config.min_samples_for_weight,)).fetchone()[0]
return {
"total_classifications": total_classifications,
"confirmed_count": confirmed_count,
"pending_count": total_classifications - confirmed_count,
"correct_count": correct_count,
"overall_accuracy": correct_count / confirmed_count if confirmed_count > 0 else 0,
"tracked_analysts": analyst_count,
"average_analyst_accuracy": avg_accuracy or 0,
"db_path": self.config.db_path
}
def bootstrap_from_frontmatter(self, docs_path: str):
"""
Bootstrap learning from documents with existing type metadata.
Uses frontmatter type declarations as ground truth.
"""
import yaml
import re
docs_dir = Path(docs_path)
bootstrapped = 0
for md_file in docs_dir.rglob("*.md"):
try:
content = md_file.read_text(encoding='utf-8')
# Parse frontmatter
if content.strip().startswith('---'):
match = re.match(r'^---\s*\n(.*?)\n---\s*\n', content, re.DOTALL)
if match:
frontmatter = yaml.safe_load(match.group(1)) or {}
declared_type = (
frontmatter.get('component_type') or
frontmatter.get('type') or
frontmatter.get('doc_type')
)
if declared_type:
# Record as confirmed classification
self.record_classification(
document_path=str(md_file),
predicted_type=declared_type,
confidence=0.95,
agreement_ratio=1.0,
analyst_votes=[{
"agent": "frontmatter_bootstrap",
"classification": declared_type,
"confidence": 0.95
}],
approval_type="BOOTSTRAP"
)
# Immediately confirm
with self._get_connection() as conn:
conn.execute("""
UPDATE classification_outcomes
SET actual_type = ?, confirmed_at = ?
WHERE document_path = ? AND actual_type IS NULL
""", (declared_type, datetime.now(timezone.utc).isoformat(), str(md_file)))
bootstrapped += 1
except Exception as e:
logger.warning(f"Failed to bootstrap {md_file}: {e}")
logger.info(f"Bootstrapped {bootstrapped} documents from frontmatter")
return bootstrapped
# Singleton instance (original line lacked the '#', a syntax error).
_learner: Optional[ClassificationLearner] = None
def get_learner(config: Optional[LearningConfig] = None) -> ClassificationLearner:
    """Get or create singleton learner instance.

    Args:
        config: Used only on first call, when the singleton is created;
            ignored on subsequent calls.
    """
    global _learner
    if _learner is None:
        _learner = ClassificationLearner(config)
    return _learner
if __name__ == "__main__":
    # Smoke-test the learning system end to end.
    import tempfile

    logging.basicConfig(level=logging.INFO)

    # Fix vs. original: db_path=":memory:" cannot work here because
    # _get_connection opens a new sqlite3 connection per call, and every
    # connection to ":memory:" is a brand-new empty database — the tables
    # created by _init_database would be gone before the first insert.
    # A temporary file keeps state across calls.
    with tempfile.TemporaryDirectory() as tmp_dir:
        learner = ClassificationLearner(
            LearningConfig(db_path=str(Path(tmp_dir) / "moe_learning_test.db"))
        )

        # Test recording a classification with three analyst votes.
        learner.record_classification(
            document_path="test/doc.md",
            predicted_type="guide",
            confidence=0.85,
            agreement_ratio=0.8,
            analyst_votes=[
                {"agent": "structural", "classification": "guide", "confidence": 0.9},
                {"agent": "content", "classification": "guide", "confidence": 0.8},
                {"agent": "metadata", "classification": "reference", "confidence": 0.7}
            ]
        )

        # Test confirmation (updates analyst accuracy and weights).
        learner.confirm_classification("test/doc.md", "guide")

        # Check stats.
        stats = learner.get_stats()
        print(f"Stats: {json.dumps(stats, indent=2)}")

        # Check analyst performance.
        performances = learner.get_analyst_performance()
        for p in performances:
            print(f"Analyst {p.analyst}: accuracy={p.accuracy:.1%}, weight={p.dynamic_weight:.2f}")