#!/usr/bin/env python3
"""Batch Classification Runner

High-performance batch processing for classifying large document sets.

Features:
- Resumable processing with checkpoint files
- Parallel document processing
- Detailed progress reporting
- Multiple output formats
- Statistics and analytics

Usage:
    # Classify all docs in repository
    python batch_classify.py

    # Resume from checkpoint
    python batch_classify.py --resume

    # Process specific directory with workers
    python batch_classify.py --path docs/ --workers 4

    # Generate detailed report
    python batch_classify.py --report
"""
import argparse
import json
import logging
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Set

# Make the package root importable when this file is run as a script.
# NOTE: was `Path(file)` — `file` is undefined; `__file__` is the script path.
sys.path.insert(0, str(Path(__file__).parent))

from core.models import Document, ClassificationResult, ApprovalType
from core.orchestrator import create_default_orchestrator, MoEOrchestrator

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Was `getLogger(name)` — `name` is undefined; `__name__` is the convention.
logger = logging.getLogger(__name__)
@dataclass
class BatchState:
    """State for resumable batch processing.

    Serialized to JSON via :meth:`to_dict` so an interrupted run can be
    resumed exactly where it left off via :meth:`from_dict`.
    """
    started_at: str  # ISO-8601 UTC timestamp of when the run started
    total_files: int  # number of files discovered at run start
    processed_files: int  # files handled so far (successes + failures)
    completed: Set[str] = field(default_factory=set)  # paths already classified
    failed: Dict[str, str] = field(default_factory=dict)  # path -> error message
    last_checkpoint: str = ""  # ISO-8601 timestamp of the last checkpoint save

    def to_dict(self) -> Dict:
        """Serialize to a JSON-compatible dict (the path set becomes a list)."""
        return {
            'started_at': self.started_at,
            'total_files': self.total_files,
            'processed_files': self.processed_files,
            'completed': list(self.completed),
            'failed': self.failed,
            'last_checkpoint': self.last_checkpoint
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'BatchState':
        """Reconstruct state from a checkpoint dict.

        Tolerates missing optional keys (``completed``, ``failed``,
        ``last_checkpoint``) so older checkpoint files still load.
        """
        return cls(
            started_at=data['started_at'],
            total_files=data['total_files'],
            processed_files=data['processed_files'],
            completed=set(data.get('completed', [])),
            failed=data.get('failed', {}),
            last_checkpoint=data.get('last_checkpoint', '')
        )
@dataclass
class BatchStats:
    """Statistics accumulated during batch classification."""
    total: int = 0  # files queued for processing
    processed: int = 0  # files attempted (successes + failures)
    successful: int = 0
    failed: int = 0
    auto_approved: int = 0
    judge_approved: int = 0
    escalated: int = 0
    # classification label -> count
    by_classification: Dict[str, int] = field(default_factory=dict)
    # confidence-bucket counts; default_factory gives each instance its own dict
    by_confidence: Dict[str, int] = field(default_factory=lambda: {
        'high': 0,    # >= 0.80
        'medium': 0,  # 0.60 - 0.80
        'low': 0      # < 0.60
    })
    total_time_seconds: float = 0.0
    avg_confidence: float = 0.0
class BatchRunner:
    """High-performance batch classification runner.

    Walks a directory tree for markdown documents, classifies each one via
    the MoE orchestrator, and supports resumable (checkpointed), optionally
    parallel runs with JSON results and an optional markdown report.
    """

    CHECKPOINT_FILE = '.moe_batch_checkpoint.json'
    RESULTS_FILE = 'classification_results.json'
    SUPPORTED_EXTENSIONS = {'.md', '.markdown'}

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.orchestrator: Optional[MoEOrchestrator] = None
        self.state: Optional[BatchState] = None
        self.stats = BatchStats()
        self.results: List[ClassificationResult] = []
        # Paths
        self.base_path = Path(args.path) if args.path else self._find_docs_root()
        self.output_dir = Path(args.output_dir) if args.output_dir else self.base_path

    def _find_docs_root(self) -> Path:
        """Find the docs root directory."""
        # Look for common doc directories
        candidates = [
            Path.cwd() / 'docs',
            Path.cwd().parent / 'docs',
            Path.cwd()
        ]
        for candidate in candidates:
            if candidate.exists() and candidate.is_dir():
                return candidate
        return Path.cwd()

    def run(self) -> int:
        """Execute batch classification.

        Returns a process exit code: 0 on full success, 1 if any file
        failed (or on unexpected error), 130 on keyboard interrupt.
        A checkpoint is saved on interrupt/error so the run can resume.
        """
        try:
            start_time = time.time()

            # Initialize
            logger.info("Initializing batch classifier...")
            self.orchestrator = create_default_orchestrator()

            # Collect files
            files = self._collect_files()
            if not files:
                logger.warning("No files found to classify")
                return 0

            # Load or create state
            if self.args.resume and self._checkpoint_exists():
                self.state = self._load_checkpoint()
                files = self._filter_completed(files)
                logger.info(f"Resuming: {len(files)} remaining of {self.state.total_files}")
            else:
                self.state = BatchState(
                    started_at=datetime.now(timezone.utc).isoformat(),
                    total_files=len(files),
                    processed_files=0
                )

            self.stats.total = len(files)
            logger.info(f"Processing {len(files)} files...")

            # Process in batches
            if self.args.workers > 1:
                self._process_parallel(files)
            else:
                self._process_sequential(files)

            # Finalize
            self.stats.total_time_seconds = time.time() - start_time
            self._calculate_final_stats()

            # Generate outputs
            self._save_results()
            if self.args.report:
                self._generate_report()

            # Show summary
            self._print_summary()

            # Clean up checkpoint on success
            if self.stats.failed == 0:
                self._cleanup_checkpoint()

            return 0 if self.stats.failed == 0 else 1

        except KeyboardInterrupt:
            logger.info("\nBatch processing interrupted - saving checkpoint...")
            self._save_checkpoint()
            return 130
        except Exception as e:
            logger.error(f"Batch processing failed: {e}")
            self._save_checkpoint()
            if self.args.verbose:
                import traceback
                traceback.print_exc()
            return 1

    def _collect_files(self) -> List[Path]:
        """Collect all files to process, sorted, with exclusions applied."""
        files = []
        for ext in self.SUPPORTED_EXTENSIONS:
            files.extend(self.base_path.rglob(f'*{ext}'))

        # Apply exclusions (substring match anywhere in the path)
        exclusions = {'node_modules', '.git', '__pycache__', '.venv', 'venv'}
        if self.args.exclude:
            exclusions.update(self.args.exclude.split(','))
        files = [
            f for f in files
            if not any(ex in str(f) for ex in exclusions)
        ]
        files.sort()
        return files

    def _filter_completed(self, files: List[Path]) -> List[Path]:
        """Filter out files already completed per the loaded checkpoint."""
        return [f for f in files if str(f) not in self.state.completed]

    def _process_sequential(self, files: List[Path]):
        """Process files sequentially."""
        for i, file_path in enumerate(files, 1):
            self._process_file(file_path, i, len(files))
            # Checkpoint every N files
            if i % self.args.checkpoint_interval == 0:
                self._save_checkpoint()

    def _process_parallel(self, files: List[Path]):
        """Process files in parallel.

        Workers only run the classification; results/stats are recorded in
        the main thread as futures complete, so no locking is needed.
        """
        with ThreadPoolExecutor(max_workers=self.args.workers) as executor:
            futures = {
                executor.submit(self._classify_file, f): f
                for f in files
            }
            for i, future in enumerate(as_completed(futures), 1):
                file_path = futures[future]
                try:
                    result = future.result()
                    self._record_result(file_path, result)
                except Exception as e:
                    self._record_error(file_path, str(e))
                self._show_progress(i, len(files))
                if i % self.args.checkpoint_interval == 0:
                    self._save_checkpoint()

    def _process_file(self, file_path: Path, current: int, total: int):
        """Process a single file, recording success or failure."""
        try:
            result = self._classify_file(file_path)
            self._record_result(file_path, result)
        except Exception as e:
            self._record_error(file_path, str(e))
        self._show_progress(current, total)

    def _classify_file(self, file_path: Path) -> ClassificationResult:
        """Classify a single file via the orchestrator."""
        doc = Document.from_path(file_path)
        return self.orchestrator.classify(doc)

    def _record_result(self, file_path: Path, result: ClassificationResult):
        """Record a successful classification result and update stats."""
        self.results.append(result)
        self.state.completed.add(str(file_path))
        self.state.processed_files += 1
        self.stats.processed += 1
        self.stats.successful += 1

        # Update stats
        approval = result.result.approval_type
        if approval == ApprovalType.AUTO_APPROVED:
            self.stats.auto_approved += 1
        elif approval == ApprovalType.JUDGE_APPROVED:
            self.stats.judge_approved += 1
        elif approval == ApprovalType.ESCALATED:
            self.stats.escalated += 1

        cls = result.result.classification or 'unknown'
        self.stats.by_classification[cls] = self.stats.by_classification.get(cls, 0) + 1

        conf = result.result.confidence
        if conf >= 0.80:
            self.stats.by_confidence['high'] += 1
        elif conf >= 0.60:
            self.stats.by_confidence['medium'] += 1
        else:
            self.stats.by_confidence['low'] += 1

    def _record_error(self, file_path: Path, error: str):
        """Record a failed classification."""
        self.state.failed[str(file_path)] = error
        self.state.processed_files += 1
        self.stats.processed += 1
        self.stats.failed += 1
        if self.args.verbose:
            logger.error(f"Failed: {file_path}: {error}")

    def _show_progress(self, current: int, total: int):
        """Show an in-place progress bar unless --quiet is set."""
        if self.args.quiet or total == 0:
            return
        pct = current / total
        filled = int(50 * pct)
        bar = '█' * filled + '░' * (50 - filled)
        print(f"\r[{bar}] {current}/{total} ({pct:.1%}) | "
              f"✓{self.stats.successful} ✗{self.stats.failed}", end='', flush=True)
        if current == total:
            print()

    def _calculate_final_stats(self):
        """Calculate final statistics (average confidence over successes)."""
        if self.results:
            total_conf = sum(r.result.confidence for r in self.results)
            self.stats.avg_confidence = total_conf / len(self.results)

    def _save_checkpoint(self):
        """Save checkpoint for resume.

        No-op if no state exists yet (e.g. failure during initialization) —
        previously this raised AttributeError from inside an except handler.
        """
        if self.state is None:
            return
        self.state.last_checkpoint = datetime.now(timezone.utc).isoformat()
        checkpoint_path = self.output_dir / self.CHECKPOINT_FILE
        with open(checkpoint_path, 'w') as f:
            json.dump(self.state.to_dict(), f, indent=2)
        logger.debug(f"Checkpoint saved: {checkpoint_path}")

    def _load_checkpoint(self) -> BatchState:
        """Load checkpoint from file."""
        checkpoint_path = self.output_dir / self.CHECKPOINT_FILE
        with open(checkpoint_path) as f:
            data = json.load(f)
        return BatchState.from_dict(data)

    def _checkpoint_exists(self) -> bool:
        """Check if checkpoint file exists."""
        return (self.output_dir / self.CHECKPOINT_FILE).exists()

    def _cleanup_checkpoint(self):
        """Remove checkpoint file after successful completion."""
        checkpoint_path = self.output_dir / self.CHECKPOINT_FILE
        if checkpoint_path.exists():
            checkpoint_path.unlink()
            logger.debug("Checkpoint cleaned up")

    def _save_results(self):
        """Save classification results as JSON."""
        results_path = self.output_dir / self.RESULTS_FILE
        output = {
            'metadata': {
                'generated_at': datetime.now(timezone.utc).isoformat(),
                'base_path': str(self.base_path),
                'total_files': self.stats.total,
                'processing_time_seconds': round(self.stats.total_time_seconds, 2)
            },
            'stats': {
                'processed': self.stats.processed,
                'successful': self.stats.successful,
                'failed': self.stats.failed,
                'auto_approved': self.stats.auto_approved,
                'judge_approved': self.stats.judge_approved,
                'escalated': self.stats.escalated,
                'avg_confidence': round(self.stats.avg_confidence, 3),
                'by_classification': self.stats.by_classification,
                'by_confidence': self.stats.by_confidence
            },
            'results': [r.to_dict() for r in self.results],
            'errors': self.state.failed if self.state else {}
        }
        # default=str stringifies anything json can't encode (e.g. Paths)
        with open(results_path, 'w') as f:
            json.dump(output, f, indent=2, default=str)
        logger.info(f"Results saved to: {results_path}")

    def _generate_report(self):
        """Generate detailed markdown report."""
        report_path = self.output_dir / 'classification_report.md'
        with open(report_path, 'w') as f:
            f.write("# MoE Document Classification Report\n\n")
            f.write(f"**Generated:** {datetime.now(timezone.utc).isoformat()}\n")
            f.write(f"**Base Path:** `{self.base_path}`\n\n")

            # Executive Summary
            f.write("## Executive Summary\n\n")
            # max(1, ...) guards against division by zero when nothing succeeded
            approval_rate = (self.stats.auto_approved + self.stats.judge_approved) / max(1, self.stats.successful)
            f.write(f"- **Total Documents:** {self.stats.total}\n")
            f.write(f"- **Successfully Classified:** {self.stats.successful}\n")
            f.write(f"- **Approval Rate:** {approval_rate:.1%}\n")
            f.write(f"- **Average Confidence:** {self.stats.avg_confidence:.1%}\n")
            f.write(f"- **Processing Time:** {self.stats.total_time_seconds:.1f}s\n\n")

            # Approval Status
            f.write("## Approval Status\n\n")
            f.write("| Status | Count | Percentage |\n")
            f.write("|--------|-------|------------|\n")
            total = max(1, self.stats.successful)
            f.write(f"| Auto-Approved | {self.stats.auto_approved} | {self.stats.auto_approved/total:.1%} |\n")
            f.write(f"| Judge-Approved | {self.stats.judge_approved} | {self.stats.judge_approved/total:.1%} |\n")
            f.write(f"| Escalated | {self.stats.escalated} | {self.stats.escalated/total:.1%} |\n\n")

            # Classification Distribution
            f.write("## Classification Distribution\n\n")
            f.write("| Type | Count | Percentage |\n")
            f.write("|------|-------|------------|\n")
            for cls, count in sorted(self.stats.by_classification.items(), key=lambda x: -x[1]):
                f.write(f"| {cls} | {count} | {count/total:.1%} |\n")
            f.write("\n")

            # Confidence Distribution
            f.write("## Confidence Distribution\n\n")
            f.write("| Level | Count | Percentage |\n")
            f.write("|-------|-------|------------|\n")
            level_labels = {'high': '≥80%', 'medium': '60-80%', 'low': '<60%'}
            for level, count in self.stats.by_confidence.items():
                f.write(f"| {level.title()} ({level_labels[level]}) | {count} | {count/total:.1%} |\n")
            f.write("\n")

            # Escalated Files
            escalated = [r for r in self.results if r.result.approval_type == ApprovalType.ESCALATED]
            if escalated:
                f.write("## Escalated Files (Need Review)\n\n")
                for r in escalated[:50]:  # Limit to 50
                    f.write(f"### `{r.document_path}`\n")
                    f.write(f"- **Proposed:** {r.result.classification}\n")
                    f.write(f"- **Confidence:** {r.result.confidence:.1%}\n")
                    f.write(f"- **Reason:** {r.result.escalation_reason}\n\n")
                if len(escalated) > 50:
                    f.write(f"*...and {len(escalated) - 50} more escalated files*\n\n")

            # Errors
            if self.state and self.state.failed:
                f.write("## Errors\n\n")
                for path, error in list(self.state.failed.items())[:20]:
                    f.write(f"- `{path}`: {error}\n")
                if len(self.state.failed) > 20:
                    f.write(f"\n*...and {len(self.state.failed) - 20} more errors*\n")

        logger.info(f"Report saved to: {report_path}")

    def _print_summary(self):
        """Print summary to console."""
        print("\n" + "="*70)
        print("BATCH CLASSIFICATION COMPLETE")
        print("="*70)
        approval_rate = (self.stats.auto_approved + self.stats.judge_approved) / max(1, self.stats.successful)
        files_per_sec = self.stats.processed / max(0.1, self.stats.total_time_seconds)
        print(f"\nProcessed: {self.stats.processed} files in {self.stats.total_time_seconds:.1f}s ({files_per_sec:.1f} files/sec)")
        print(f"Successful: {self.stats.successful} | Failed: {self.stats.failed}")
        print(f"Approval Rate: {approval_rate:.1%} | Avg Confidence: {self.stats.avg_confidence:.1%}")
        print(f"\nApproval Status:")
        print(f"  Auto-approved: {self.stats.auto_approved}")
        print(f"  Judge-approved: {self.stats.judge_approved}")
        print(f"  Escalated: {self.stats.escalated}")
        print(f"\nTop Classifications:")
        for cls, count in sorted(self.stats.by_classification.items(), key=lambda x: -x[1])[:5]:
            print(f"  {cls:15}: {count}")
        print()
def create_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser for the batch runner."""
    parser = argparse.ArgumentParser(
        description='Batch Document Classification Runner',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        '--path',
        type=str,
        help='Base path to process (default: auto-detect docs/)'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        help='Output directory for results (default: same as path)'
    )
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='Number of parallel workers (default: 1)'
    )
    parser.add_argument(
        '--resume',
        action='store_true',
        help='Resume from checkpoint'
    )
    parser.add_argument(
        '--report',
        action='store_true',
        help='Generate detailed markdown report'
    )
    parser.add_argument(
        '--exclude',
        type=str,
        help='Comma-separated patterns to exclude'
    )
    parser.add_argument(
        '--checkpoint-interval',
        type=int,
        default=100,
        help='Save checkpoint every N files (default: 100)'
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='Verbose output'
    )
    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
        help='Suppress progress output'
    )
    return parser
def main() -> int:
    """Main entry point: parse args, configure logging, run the batch."""
    parser = create_parser()
    args = parser.parse_args()

    # Adjust root-logger verbosity before any work happens.
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)
    elif args.quiet:
        logging.getLogger().setLevel(logging.WARNING)

    runner = BatchRunner(args)
    return runner.run()
# Was `if name == 'main':` — `name` is undefined; the dunder form is required
# for the guard to actually fire when run as a script.
if __name__ == '__main__':
    sys.exit(main())