#!/usr/bin/env python3
"""
CODITECT Circuit Breaker Service

Implements the Circuit Breaker pattern to prevent cascading failures
when calling agents or external services.

States:
    - CLOSED: Normal operation, requests pass through
    - OPEN: Too many failures, requests fail immediately (fast-fail)
    - HALF_OPEN: Testing recovery, limited requests allowed

Key Features:
    - Configurable failure thresholds and timeouts
    - Per-agent circuit breakers with automatic creation
    - Integration with AgentDiscoveryService for fallback routing
    - Async support for non-blocking operations
    - Metrics tracking for monitoring
    - Optional pybreaker compatibility

Part of Track H.2: Inter-Agent Communication Infrastructure

Usage:
    from circuit_breaker import CircuitBreaker, AgentCircuitBreaker

    # Simple usage
    breaker = CircuitBreaker("my-service")
    result = await breaker.call(my_async_function, arg1, arg2)

    # With decorator
    @circuit_breaker("api-service")
    async def call_api():
        return await fetch_data()

    # Agent circuit breaker with fallback
    agent_breaker = AgentCircuitBreaker(discovery_service)
    result = await agent_breaker.call_agent("agent-1", task_func, task_data)

Author: CODITECT Framework
Version: 1.0.0
Created: January 8, 2026
"""

import asyncio
import inspect
import logging
import os
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import wraps
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Set,
    Type,
    TypeVar,
    Union,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
# The logger is named after this module (__name__) so its records can be
# filtered and configured independently of other modules.
logger = logging.getLogger(__name__)

# Type variable for generic async functions
T = TypeVar('T')

# =============================================================================
# Enums
# =============================================================================

class CircuitState(Enum):
    """Circuit breaker states."""

    CLOSED = "closed"        # Normal operation
    OPEN = "open"            # Failing fast, rejecting requests
    HALF_OPEN = "half_open"  # Testing recovery

class FailureType(Enum):
    """Types of failures to track."""

    EXCEPTION = "exception"
    TIMEOUT = "timeout"
    REJECTION = "rejection"  # Circuit open rejection
    CUSTOM = "custom"

# =============================================================================
# Exceptions
# =============================================================================

class CircuitBreakerError(Exception):
    """Base exception for circuit breaker errors."""


class CircuitOpenError(CircuitBreakerError):
    """Raised when the circuit is open and a request is rejected.

    The instance exposes ``breaker_name`` (name of the rejecting breaker)
    and ``time_remaining`` (seconds the caller is told to wait), and its
    message embeds both values.
    """

    # The constructor must be the ``__init__`` dunder: a plain method named
    # ``init`` is never invoked on construction, which would leave the
    # attributes unset and the exception message empty.
    def __init__(self, breaker_name: str, time_remaining: float = 0):
        self.breaker_name = breaker_name
        self.time_remaining = time_remaining
        super().__init__(
            f"Circuit breaker '{breaker_name}' is OPEN. "
            f"Retry in {time_remaining:.1f}s"
        )


class CircuitBreakerConfigError(CircuitBreakerError):
    """Raised for configuration errors."""

# =============================================================================
# Data Classes
# =============================================================================

@dataclass
class CircuitBreakerConfig:
    """Configuration for a circuit breaker.

    Values are validated in ``__post_init__``; an invalid ``fail_max``,
    ``recovery_timeout`` or ``half_open_max_calls`` raises
    ``CircuitBreakerConfigError``.
    """

    # Failure thresholds
    fail_max: int = 5                   # Open circuit after this many failures
    fail_window_seconds: float = 60.0   # Time window for counting failures

    # Recovery settings
    recovery_timeout: float = 60.0      # How long to stay open before testing
    half_open_max_calls: int = 3        # Max test calls in half-open state
    success_threshold: int = 2          # Successes needed to close circuit

    # Timeouts
    call_timeout: Optional[float] = 30.0  # Default timeout for calls (None = no timeout)

    # Exceptions
    exclude_exceptions: List[Type[Exception]] = field(default_factory=list)
    include_exceptions: Optional[List[Type[Exception]]] = None  # None = all

    # Listeners (each receives the breaker whose state changed)
    on_open: Optional[Callable[["CircuitBreaker"], None]] = None
    on_close: Optional[Callable[["CircuitBreaker"], None]] = None
    on_half_open: Optional[Callable[["CircuitBreaker"], None]] = None

    def __post_init__(self):
        """Reject values the state machine cannot work with."""
        if self.fail_max < 1:
            raise CircuitBreakerConfigError("fail_max must be at least 1")
        if self.recovery_timeout < 0:
            raise CircuitBreakerConfigError("recovery_timeout must be non-negative")
        if self.half_open_max_calls < 1:
            raise CircuitBreakerConfigError("half_open_max_calls must be at least 1")

    @classmethod
    def from_env(cls, prefix: str = "CIRCUIT_BREAKER") -> "CircuitBreakerConfig":
        """Create config from environment variables.

        Reads ``{prefix}_FAIL_MAX``, ``{prefix}_FAIL_WINDOW``,
        ``{prefix}_RECOVERY_TIMEOUT``, ``{prefix}_HALF_OPEN_MAX``,
        ``{prefix}_SUCCESS_THRESHOLD`` and ``{prefix}_CALL_TIMEOUT``; a
        missing variable falls back to the same default as the dataclass
        field.

        Args:
            prefix: Prefix of the environment variable names.

        Returns:
            A validated ``CircuitBreakerConfig``.

        Raises:
            ValueError: If a variable is set to a non-numeric string.
            CircuitBreakerConfigError: If a parsed value fails validation.
        """
        env = os.environ.get
        return cls(
            fail_max=int(env(f"{prefix}_FAIL_MAX", "5")),
            fail_window_seconds=float(env(f"{prefix}_FAIL_WINDOW", "60")),
            recovery_timeout=float(env(f"{prefix}_RECOVERY_TIMEOUT", "60")),
            half_open_max_calls=int(env(f"{prefix}_HALF_OPEN_MAX", "3")),
            success_threshold=int(env(f"{prefix}_SUCCESS_THRESHOLD", "2")),
            # A value of 0 is falsy and therefore becomes None (no timeout).
            call_timeout=float(env(f"{prefix}_CALL_TIMEOUT", "30")) or None
        )

@dataclass
class FailureRecord:
    """Record of a single failure."""

    timestamp: datetime                   # When the failure was recorded
    failure_type: FailureType             # Category of the failure
    exception_type: Optional[str] = None  # Class name of the exception, if any
    message: Optional[str] = None         # str() of the exception, if any

@dataclass
class CircuitMetrics:
    """Metrics for a circuit breaker."""

    name: str
    state: CircuitState
    failure_count: int
    success_count: int
    rejection_count: int
    last_failure_time: Optional[datetime]
    last_success_time: Optional[datetime]
    last_state_change: datetime
    total_calls: int
    failure_rate: float  # 0.0 to 1.0

    def to_dict(self) -> Dict:
        """Return the metrics as a plain dict.

        The state is emitted as its enum value, datetimes as ISO-8601
        strings (``None`` when unset), and ``failure_rate`` rounded to four
        decimal places.
        """
        def iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        return {
            "name": self.name,
            "state": self.state.value,
            "failure_count": self.failure_count,
            "success_count": self.success_count,
            "rejection_count": self.rejection_count,
            "last_failure_time": iso(self.last_failure_time),
            "last_success_time": iso(self.last_success_time),
            "last_state_change": self.last_state_change.isoformat(),
            "total_calls": self.total_calls,
            "failure_rate": round(self.failure_rate, 4)
        }

# =============================================================================
# Circuit Breaker Implementation
# =============================================================================

class CircuitBreaker:
    """
    Circuit breaker implementation with state machine.

    States:
    - CLOSED: Normal operation. Failures are counted.
    - OPEN: Circuit is open. Requests fail immediately.
    - HALF_OPEN: Testing recovery. Limited requests allowed.

    Transitions:
    - CLOSED -> OPEN: When failures inside the window reach ``fail_max``
    - OPEN -> HALF_OPEN: After recovery timeout
    - HALF_OPEN -> CLOSED: When test requests succeed
    - HALF_OPEN -> OPEN: When test requests fail

    State and counters are only mutated while ``self._lock`` (an ``RLock``)
    is held. Listener callbacks from the config are invoked from
    ``_transition_to`` and therefore run with the lock held.
    """

    def __init__(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None
    ):
        """
        Args:
            name: Identifier used in log messages, errors and metrics.
            config: Breaker settings; a default config is used when omitted.
        """
        self.name = name
        self.config = config or CircuitBreakerConfig()

        # State
        self._state = CircuitState.CLOSED
        self._state_changed_at = datetime.utcnow()

        # Failure tracking (sliding window)
        self._failures: deque = deque()
        self._lock = threading.RLock()

        # Half-open state tracking
        self._half_open_successes = 0
        self._half_open_calls = 0

        # Metrics
        self._total_calls = 0
        self._total_successes = 0
        self._total_failures = 0
        self._total_rejections = 0
        self._last_failure_time: Optional[datetime] = None
        self._last_success_time: Optional[datetime] = None

    @property
    def state(self) -> CircuitState:
        """Get current circuit state (applies a due OPEN -> HALF_OPEN move)."""
        with self._lock:
            self._check_state_transition()
            return self._state

    @property
    def is_closed(self) -> bool:
        """True when the current state is CLOSED."""
        return self.state == CircuitState.CLOSED

    @property
    def is_open(self) -> bool:
        """True when the current state is OPEN."""
        return self.state == CircuitState.OPEN

    @property
    def is_half_open(self) -> bool:
        """True when the current state is HALF_OPEN."""
        return self.state == CircuitState.HALF_OPEN

    def _check_state_transition(self):
        """Check if state should transition (called with lock held)"""
        if self._state == CircuitState.OPEN:
            # Check if recovery timeout has passed
            elapsed = (datetime.utcnow() - self._state_changed_at).total_seconds()
            if elapsed >= self.config.recovery_timeout:
                self._transition_to(CircuitState.HALF_OPEN)

    def _transition_to(self, new_state: CircuitState):
        """Transition to a new state (called with lock held)"""
        old_state = self._state
        self._state = new_state
        self._state_changed_at = datetime.utcnow()

        logger.info(
            f"Circuit breaker '{self.name}' transitioned: "
            f"{old_state.value} -> {new_state.value}"
        )

        if new_state == CircuitState.HALF_OPEN:
            # Reset half-open counters so every probe window starts fresh
            self._half_open_successes = 0
            self._half_open_calls = 0
            if self.config.on_half_open:
                self.config.on_half_open(self)
        elif new_state == CircuitState.OPEN:
            if self.config.on_open:
                self.config.on_open(self)
        elif new_state == CircuitState.CLOSED:
            # A closed circuit starts with an empty failure window
            self._failures.clear()
            if self.config.on_close:
                self.config.on_close(self)

    def _count_recent_failures(self) -> int:
        """Count failures within the sliding window (evicts expired ones)"""
        cutoff = datetime.utcnow() - timedelta(seconds=self.config.fail_window_seconds)

        # Records are appended in time order, so expired ones are at the left
        while self._failures and self._failures[0].timestamp < cutoff:
            self._failures.popleft()

        return len(self._failures)

    def _record_failure(self, failure_type: FailureType, exc: Optional[Exception] = None):
        """Record a failure and open the circuit when warranted"""
        with self._lock:
            record = FailureRecord(
                timestamp=datetime.utcnow(),
                failure_type=failure_type,
                exception_type=type(exc).__name__ if exc else None,
                message=str(exc) if exc else None
            )
            self._failures.append(record)
            self._total_failures += 1
            self._last_failure_time = record.timestamp

            if self._state == CircuitState.CLOSED:
                # Check if we should open the circuit
                if self._count_recent_failures() >= self.config.fail_max:
                    self._transition_to(CircuitState.OPEN)
            elif self._state == CircuitState.HALF_OPEN:
                # Any failure in half-open state reopens the circuit
                self._transition_to(CircuitState.OPEN)

    def _record_success(self):
        """Record a success and close the circuit when enough probes passed"""
        with self._lock:
            self._total_successes += 1
            self._last_success_time = datetime.utcnow()

            if self._state == CircuitState.HALF_OPEN:
                self._half_open_successes += 1
                # At most half_open_max_calls probes are admitted, so a
                # larger success_threshold could never be reached and the
                # breaker would stay HALF_OPEN forever; cap it.
                required = min(
                    self.config.success_threshold,
                    self.config.half_open_max_calls
                )
                if self._half_open_successes >= required:
                    self._transition_to(CircuitState.CLOSED)

    def _should_allow_request(self) -> bool:
        """Check if a request should be allowed through"""
        with self._lock:
            self._check_state_transition()

            if self._state == CircuitState.CLOSED:
                return True

            if self._state == CircuitState.HALF_OPEN:
                # Allow limited requests for testing
                if self._half_open_calls < self.config.half_open_max_calls:
                    self._half_open_calls += 1
                    return True
                return False

            # OPEN (or any unknown state): reject
            return False

    def _admit_request(self) -> bool:
        """Count the call and decide whether it may proceed.

        Shared by ``call`` and ``call_sync`` so the call/rejection counters
        are updated under the same lock as the rest of the metrics.

        Returns:
            True when the admitted call consumed a half-open probe slot,
            False when it was admitted in the CLOSED state.

        Raises:
            CircuitOpenError: If the request is rejected.
        """
        with self._lock:
            self._total_calls += 1
            if self._should_allow_request():
                return self._state == CircuitState.HALF_OPEN

            self._total_rejections += 1
            elapsed = (datetime.utcnow() - self._state_changed_at).total_seconds()
            time_remaining = max(0, self.config.recovery_timeout - elapsed)
        raise CircuitOpenError(self.name, time_remaining)

    def _release_half_open_slot(self):
        """Give back a probe slot whose call produced no recorded outcome.

        A probe that ends with an excluded exception or a cancellation is
        neither a success nor a failure. Without returning its slot the
        breaker could exhaust ``half_open_max_calls`` and then reject every
        request while never leaving HALF_OPEN.
        """
        with self._lock:
            if self._state == CircuitState.HALF_OPEN and self._half_open_calls > 0:
                self._half_open_calls -= 1

    def _should_count_exception(self, exc: Exception) -> bool:
        """Check if exception should count as a failure"""
        # Exclude list wins over the include list
        if isinstance(exc, tuple(self.config.exclude_exceptions)):
            return False

        included = self.config.include_exceptions
        if included is None:
            # No include list configured: every exception counts
            return True
        return isinstance(exc, tuple(included))

    async def call(
        self,
        func: Callable[..., T],
        *args,
        timeout: Optional[float] = None,
        **kwargs
    ) -> T:
        """
        Call a function through the circuit breaker.

        Args:
            func: Async function to call
            *args: Positional arguments for func
            timeout: Override default timeout (None = use config)
            **kwargs: Keyword arguments for func

        Returns:
            Result of the function call

        Raises:
            CircuitOpenError: If circuit is open
            asyncio.TimeoutError: If call times out
            Exception: Any exception from the function
        """
        is_probe = self._admit_request()

        # Determine timeout
        call_timeout = timeout if timeout is not None else self.config.call_timeout

        try:
            if call_timeout:
                result = await asyncio.wait_for(
                    func(*args, **kwargs),
                    timeout=call_timeout
                )
            else:
                # A falsy timeout (None or 0) means "wait indefinitely"
                result = await func(*args, **kwargs)

            self._record_success()
            return result

        except asyncio.TimeoutError as e:
            self._record_failure(FailureType.TIMEOUT, e)
            raise

        except Exception as e:
            if self._should_count_exception(e):
                self._record_failure(FailureType.EXCEPTION, e)
            elif is_probe:
                self._release_half_open_slot()
            raise

        except BaseException:
            # e.g. task cancellation: no outcome was recorded for this call
            if is_probe:
                self._release_half_open_slot()
            raise

    def call_sync(
        self,
        func: Callable[..., T],
        *args,
        **kwargs
    ) -> T:
        """
        Call a synchronous function through the circuit breaker.

        Note: No timeout support for sync calls.

        Raises:
            CircuitOpenError: If circuit is open
            Exception: Any exception from the function
        """
        is_probe = self._admit_request()

        try:
            result = func(*args, **kwargs)
            self._record_success()
            return result

        except Exception as e:
            if self._should_count_exception(e):
                self._record_failure(FailureType.EXCEPTION, e)
            elif is_probe:
                self._release_half_open_slot()
            raise

        except BaseException:
            # e.g. KeyboardInterrupt: no outcome was recorded for this call
            if is_probe:
                self._release_half_open_slot()
            raise

    def force_open(self):
        """Force the circuit to open (for testing/maintenance)"""
        with self._lock:
            self._transition_to(CircuitState.OPEN)

    def force_close(self):
        """Force the circuit to close (for recovery)"""
        with self._lock:
            self._transition_to(CircuitState.CLOSED)

    def reset(self):
        """Reset all state and metrics (no listener is invoked)"""
        with self._lock:
            self._state = CircuitState.CLOSED
            self._state_changed_at = datetime.utcnow()
            self._failures.clear()
            self._half_open_successes = 0
            self._half_open_calls = 0
            self._total_calls = 0
            self._total_successes = 0
            self._total_failures = 0
            self._total_rejections = 0
            self._last_failure_time = None
            self._last_success_time = None

    def get_metrics(self) -> CircuitMetrics:
        """Get current metrics.

        ``failure_rate`` is failures divided by successes plus failures
        (rejections are not part of the denominator) and is 0.0 before any
        call has completed.
        """
        with self._lock:
            # Apply a due OPEN -> HALF_OPEN move first so the reported state
            # agrees with what the ``state`` property would return.
            self._check_state_transition()

            completed = self._total_successes + self._total_failures
            failure_rate = self._total_failures / completed if completed > 0 else 0.0

            return CircuitMetrics(
                name=self.name,
                state=self._state,
                failure_count=self._total_failures,
                success_count=self._total_successes,
                rejection_count=self._total_rejections,
                last_failure_time=self._last_failure_time,
                last_success_time=self._last_success_time,
                last_state_change=self._state_changed_at,
                total_calls=self._total_calls,
                failure_rate=failure_rate
            )

# =============================================================================
# Circuit Breaker Registry
# =============================================================================

class CircuitBreakerRegistry:
    """
    Registry for managing multiple circuit breakers.

    Provides:
    - Automatic creation of circuit breakers
    - Default configuration management
    - Aggregate metrics
    """

    def __init__(self, default_config: Optional[CircuitBreakerConfig] = None):
        """
        Args:
            default_config: Config given to breakers created without an
                explicit one; a default config is used when omitted.
        """
        self.default_config = default_config or CircuitBreakerConfig()
        self._breakers: Dict[str, CircuitBreaker] = {}
        self._lock = threading.RLock()

    def get(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None
    ) -> CircuitBreaker:
        """Get or create a circuit breaker by name.

        ``config`` is only used when the breaker is created; it is ignored
        when a breaker with that name is already registered.
        """
        with self._lock:
            if name not in self._breakers:
                self._breakers[name] = CircuitBreaker(
                    name=name,
                    config=config or self.default_config
                )
            return self._breakers[name]

    def get_all(self) -> Dict[str, CircuitBreaker]:
        """Get all circuit breakers (a shallow copy of the mapping)"""
        with self._lock:
            return dict(self._breakers)

    def remove(self, name: str) -> bool:
        """Remove a circuit breaker; returns False if the name is unknown"""
        with self._lock:
            if name in self._breakers:
                del self._breakers[name]
                return True
            return False

    def reset_all(self):
        """Reset all circuit breakers"""
        with self._lock:
            for breaker in self._breakers.values():
                breaker.reset()

    def get_all_metrics(self) -> List[CircuitMetrics]:
        """Get metrics for all circuit breakers"""
        with self._lock:
            return [b.get_metrics() for b in self._breakers.values()]

    def get_open_circuits(self) -> List[str]:
        """Get names of all open circuits"""
        with self._lock:
            return [
                name for name, breaker in self._breakers.items()
                if breaker.is_open
            ]

    def get_summary(self) -> Dict:
        """Get summary of all circuit breakers.

        Returns:
            Dict with ``total_breakers``, per-state counts (``closed``,
            ``open``, ``half_open``) and ``open_circuits`` (names of the
            breakers counted as open).
        """
        with self._lock:
            # Reading ``state`` can itself move a breaker from OPEN to
            # HALF_OPEN once its timeout elapses, so each breaker is read
            # exactly once; every figure below derives from that one
            # snapshot and the counts always add up to the total.
            states = {name: b.state for name, b in self._breakers.items()}

        open_circuits = [
            name for name, state in states.items()
            if state == CircuitState.OPEN
        ]
        observed = list(states.values())

        return {
            "total_breakers": len(states),
            "closed": observed.count(CircuitState.CLOSED),
            "open": len(open_circuits),
            "half_open": observed.count(CircuitState.HALF_OPEN),
            "open_circuits": open_circuits
        }

# =============================================================================
# Agent Circuit Breaker (Integration with AgentDiscoveryService)
# =============================================================================

class AgentCircuitBreaker:
    """
    Circuit breaker specifically for agent calls.

    Integrates with AgentDiscoveryService to:
    - Find fallback agents when circuit is open
    - Track agent health across breakers
    - Update agent status based on circuit state
    """

    # Registry names of per-agent breakers are this prefix plus the agent id
    _BREAKER_PREFIX = "agent:"

    def __init__(
        self,
        discovery_service=None,  # AgentDiscoveryService
        config: Optional[CircuitBreakerConfig] = None
    ):
        """
        Args:
            discovery_service: Optional service used to look up fallback
                agents; without it no fallback routing happens.
            config: Default config for the per-agent breakers.
        """
        self.discovery = discovery_service
        self.registry = CircuitBreakerRegistry(config)

        # Track agent capabilities for fallback routing
        self._agent_capabilities: Dict[str, Set[str]] = {}

    def get_breaker(self, agent_id: str) -> CircuitBreaker:
        """Get circuit breaker for an agent (created on first use)"""
        return self.registry.get(f"{self._BREAKER_PREFIX}{agent_id}")

    async def call_agent(
        self,
        agent_id: str,
        func: Callable[..., T],
        *args,
        fallback: bool = True,
        capabilities: Optional[List[str]] = None,
        **kwargs
    ) -> T:
        """
        Call an agent function through circuit breaker with optional fallback.

        Args:
            agent_id: Target agent identifier
            func: Async function to call
            *args: Arguments for the function
            fallback: Whether to try fallback agents if circuit is open
            capabilities: Agent capabilities for fallback routing
            **kwargs: Keyword arguments for the function

        Returns:
            Result from agent (or fallback)

        Raises:
            CircuitOpenError: If circuit is open and no fallback available
        """
        # Store capabilities for fallback routing
        if capabilities:
            self._agent_capabilities[agent_id] = set(capabilities)

        breaker = self.get_breaker(agent_id)

        try:
            return await breaker.call(func, *args, **kwargs)

        except CircuitOpenError:
            if not fallback:
                raise

            # Try to find a fallback agent
            fallback_agent = await self._find_fallback_agent(agent_id)
            if fallback_agent:
                logger.info(
                    f"Circuit open for '{agent_id}', "
                    f"routing to fallback '{fallback_agent}'"
                )
                # The same func/args are run through the fallback's breaker
                return await self.call_agent(
                    fallback_agent,
                    func,
                    *args,
                    fallback=False,  # Don't recurse fallbacks
                    **kwargs
                )

            # No fallback available
            raise

    async def _find_fallback_agent(self, failing_agent_id: str) -> Optional[str]:
        """Find a fallback agent with same capabilities.

        Lookup is best-effort: errors raised by the discovery service are
        logged and treated as "nothing found", so the caller re-raises its
        ``CircuitOpenError`` instead of an unrelated discovery error.
        """
        if not self.discovery:
            return None

        # Get capabilities of failing agent
        capabilities = self._agent_capabilities.get(failing_agent_id)
        if not capabilities:
            # Try to get from discovery service
            try:
                agent = await self.discovery.get_agent(failing_agent_id)
            except Exception as e:
                logger.warning(f"Error finding fallback: {e}")
                agent = None
            if agent:
                capabilities = {c.name for c in agent.capabilities}
                self._agent_capabilities[failing_agent_id] = capabilities

        if not capabilities:
            return None

        # Find agents with same capabilities
        for capability in capabilities:
            try:
                candidates = await self.discovery.find_agents_by_capability(
                    capability
                )

                for candidate in candidates:
                    if candidate.id == failing_agent_id:
                        continue

                    # Check if candidate's circuit is healthy
                    candidate_breaker = self.get_breaker(candidate.id)
                    if candidate_breaker.is_closed or candidate_breaker.is_half_open:
                        return candidate.id

            except Exception as e:
                logger.warning(f"Error finding fallback: {e}")

        return None

    def get_agent_status(self, agent_id: str) -> Dict:
        """Get circuit breaker status for an agent"""
        metrics = self.get_breaker(agent_id).get_metrics()
        last_failure = metrics.last_failure_time
        return {
            "agent_id": agent_id,
            "circuit_state": metrics.state.value,
            "failure_count": metrics.failure_count,
            "success_count": metrics.success_count,
            "failure_rate": metrics.failure_rate,
            "last_failure": last_failure.isoformat() if last_failure else None
        }

    def get_all_agent_status(self) -> List[Dict]:
        """Get status for all tracked agents"""
        prefix = self._BREAKER_PREFIX
        # Strip only the leading prefix: str.replace would also remove an
        # "agent:" occurring inside the id, and the status lookup would then
        # create a breaker for a different, non-existent agent.
        return [
            self.get_agent_status(name[len(prefix):])
            for name in self.registry.get_all()
            if name.startswith(prefix)
        ]

    def force_open_agent(self, agent_id: str):
        """Force open circuit for an agent"""
        self.get_breaker(agent_id).force_open()

    def force_close_agent(self, agent_id: str):
        """Force close circuit for an agent"""
        self.get_breaker(agent_id).force_close()

    def reset_agent(self, agent_id: str):
        """Reset circuit breaker for an agent"""
        self.get_breaker(agent_id).reset()

# =============================================================================
# Decorator
# =============================================================================

# Global registry for decorator usage: circuit_breaker() falls back to it
# when no registry argument is given, and get_global_registry() returns it.
_global_registry = CircuitBreakerRegistry()

def circuit_breaker(
    name: str,
    config: Optional[CircuitBreakerConfig] = None,
    registry: Optional[CircuitBreakerRegistry] = None
):
    """
    Decorator to wrap a function with circuit breaker.

    Coroutine functions are routed through ``CircuitBreaker.call``; plain
    functions through ``CircuitBreaker.call_sync``. The breaker is looked up
    (or created) in the registry once, at decoration time.

    Args:
        name: Name of the breaker inside the registry.
        config: Config handed to the registry when requesting the breaker.
        registry: Registry to use; the module-level global registry is used
            when omitted.

    Usage:
        @circuit_breaker("my-service")
        async def call_service():
            ...

        @circuit_breaker("api", config=CircuitBreakerConfig(fail_max=3))
        async def call_api():
            ...
    """
    reg = registry if registry is not None else _global_registry

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        breaker = reg.get(name, config)

        # inspect.iscoroutinefunction is the supported check;
        # asyncio.iscoroutinefunction is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs) -> T:
                return await breaker.call(func, *args, **kwargs)
            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs) -> T:
            return breaker.call_sync(func, *args, **kwargs)
        return sync_wrapper

    return decorator

def get_global_registry() -> CircuitBreakerRegistry:
    """Get the global circuit breaker registry"""
    return _global_registry

# =============================================================================
# Utility Functions
# =============================================================================

def create_circuit_breaker(
    name: str,
    fail_max: int = 5,
    recovery_timeout: float = 60.0,
    **kwargs
) -> CircuitBreaker:
    """
    Create a circuit breaker with common settings.

    The breaker is standalone: it is not added to any registry.

    Args:
        name: Circuit breaker name
        fail_max: Failures before opening
        recovery_timeout: Seconds to stay open
        **kwargs: Additional CircuitBreakerConfig options

    Returns:
        Configured CircuitBreaker

    Raises:
        CircuitBreakerConfigError: If the resulting config is invalid
    """
    settings = CircuitBreakerConfig(
        fail_max=fail_max,
        recovery_timeout=recovery_timeout,
        **kwargs
    )
    return CircuitBreaker(name, settings)

async def with_circuit_breaker(
    breaker: CircuitBreaker,
    func: Callable[..., T],
    *args,
    **kwargs
) -> T:
    """
    Convenience function to call through a circuit breaker.

    Equivalent to ``await breaker.call(func, *args, **kwargs)``; a
    ``timeout`` keyword is therefore consumed by the breaker and not
    forwarded to ``func``.

    Usage:
        breaker = create_circuit_breaker("my-service")
        result = await with_circuit_breaker(breaker, my_func, arg1, arg2)
    """
    return await breaker.call(func, *args, **kwargs)

# =============================================================================
# Health Check Integration
# =============================================================================

class CircuitBreakerHealthCheck:
    """
    Health check integration for circuit breakers.

    Provides health status based on circuit breaker states.
    """

    def __init__(self, registry: "CircuitBreakerRegistry"):
        """
        Args:
            registry: Registry whose ``get_summary()`` drives the check.
        """
        self.registry = registry
        self.unhealthy_threshold = 0.5  # >50% open = unhealthy

    def check(self) -> Dict:
        """Run health check.

        Returns:
            Dict with ``status`` ("healthy", "degraded" or "unhealthy"),
            a human-readable ``message`` and the registry summary under
            ``details``. The status depends only on the share of OPEN
            circuits: none open is healthy, more than
            ``unhealthy_threshold`` is unhealthy, anything else degraded.
        """
        summary = self.registry.get_summary()
        total = summary["total_breakers"]
        open_count = summary["open"]
        half_open = summary.get("half_open", 0)

        if total == 0:
            status = "healthy"
            message = "No circuit breakers registered"
        elif open_count == 0:
            status = "healthy"
            if half_open:
                # Recovering circuits are not closed, so "All ... closed"
                # would misreport them.
                message = f"No circuits open ({half_open}/{total} half-open)"
            else:
                message = f"All {total} circuits closed"
        else:
            too_many_open = open_count / total > self.unhealthy_threshold
            status = "unhealthy" if too_many_open else "degraded"
            message = f"{open_count}/{total} circuits open"

        return {
            "status": status,
            "message": message,
            "details": summary
        }

# =============================================================================
# Main (Demo/Testing)
# =============================================================================

async def main():
    """Demo the circuit breaker system"""
    banner = "=" * 60
    rule = "-" * 60

    print(banner)
    print("CODITECT Circuit Breaker Demo")
    print(banner)

    # Create a circuit breaker
    config = CircuitBreakerConfig(
        fail_max=3,
        recovery_timeout=5,
        success_threshold=2
    )
    breaker = CircuitBreaker("demo-service", config)

    # Simulated service call: the first four invocations fail, later ones succeed
    call_count = 0

    async def flaky_service():
        nonlocal call_count
        call_count += 1
        if call_count <= 4:
            raise Exception(f"Service error (call {call_count})")
        return f"Success (call {call_count})"

    print("\n1. Testing circuit breaker with flaky service:")
    print(rule)

    for i in range(10):
        try:
            result = await breaker.call(flaky_service)
            print(f" Call {i+1}: {result} [State: {breaker.state.value}]")
        except CircuitOpenError as e:
            print(f" Call {i+1}: REJECTED - {e.breaker_name} [State: {breaker.state.value}]")
        except Exception as e:
            print(f" Call {i+1}: FAILED - {e} [State: {breaker.state.value}]")

        # NOTE(review): SOURCE lost its indentation; this pause is assumed
        # to belong to the loop body (one pause per call) - confirm.
        await asyncio.sleep(0.5)

    # Wait for recovery
    print("\nWaiting for recovery timeout...")
    await asyncio.sleep(5)

    # Try again after recovery
    print("\n2. After recovery timeout:")
    print(rule)

    for i in range(5):
        try:
            result = await breaker.call(flaky_service)
            print(f" Call {i+1}: {result} [State: {breaker.state.value}]")
        except CircuitOpenError:
            print(f" Call {i+1}: REJECTED [State: {breaker.state.value}]")
        except Exception as e:
            print(f" Call {i+1}: FAILED - {e} [State: {breaker.state.value}]")

    # Show metrics
    print("\n3. Final Metrics:")
    print(rule)
    metrics = breaker.get_metrics()
    for key, value in metrics.to_dict().items():
        print(f" {key}: {value}")

    # Test decorator
    print("\n4. Testing decorator:")
    print(rule)

    @circuit_breaker("decorated-service", config=CircuitBreakerConfig(fail_max=2))
    async def decorated_call():
        return "Decorated success!"

    result = await decorated_call()
    print(f" Result: {result}")

    # Test registry
    print("\n5. Registry Summary:")
    print(rule)
    summary = get_global_registry().get_summary()
    for key, value in summary.items():
        print(f" {key}: {value}")

    print("\n" + banner)
    print("Demo complete!")

# Run the demo only when the file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())