#!/usr/bin/env python3
"""
Unit Tests for CODITECT Message Bus

Part of Track H.2: Inter-Agent Communication Infrastructure

Tests cover:
    - AgentMessage serialization/deserialization
    - MessageBusConfig from various sources (defaults, env vars, JSON file)
    - LocalMessageQueue functionality
    - MessageBus connection and disconnection (local fallback mode)
    - Message sending, broadcasting, publishing and responses
    - Queue statistics
    - MessageType / MessagePriority enums and the create_task_message helper

Run tests:
    cd scripts/core
    python -m pytest test_message_bus.py -v

    # Or directly
    python test_message_bus.py

Author: CODITECT Framework
Created: January 8, 2026
Version: 1.0.0
"""

import asyncio
import json
import os
import sys
import tempfile
import threading
import time
import unittest
from datetime import datetime
from pathlib import Path
from typing import List
from unittest.mock import AsyncMock, MagicMock, patch

# Add this test file's directory to the import path so the sibling
# `message_bus` module can be imported regardless of the current working
# directory. The path is resolved to an absolute one so a later chdir
# cannot invalidate it.
script_dir = Path(__file__).resolve().parent
sys.path.insert(0, str(script_dir))

from message_bus import ( AgentMessage, MessageBus, MessageBusConfig, MessageType, MessagePriority, LocalMessageQueue, create_task_message, )

class TestAgentMessage(unittest.TestCase):
    """Test AgentMessage dataclass."""

    def test_create_default_message(self):
        """Test creating a message with default values."""
        msg = AgentMessage()

        self.assertIsNotNone(msg.id)
        self.assertEqual(len(msg.id), 36)  # UUID format
        self.assertEqual(msg.from_agent, "")
        self.assertEqual(msg.to_agent, "")
        self.assertEqual(msg.message_type, MessageType.TASK_REQUEST.value)
        self.assertEqual(msg.priority, MessagePriority.NORMAL.value)
        self.assertIsInstance(msg.payload, dict)
        self.assertIsInstance(msg.metadata, dict)

    def test_create_custom_message(self):
        """Test creating a message with custom values."""
        msg = AgentMessage(
            from_agent="agent-a",
            to_agent="agent-b",
            task_id="task-123",
            message_type=MessageType.TASK_RESPONSE.value,
            payload={"result": "success"},
            priority=MessagePriority.HIGH.value,
            metadata={"attempt": 1},
        )

        self.assertEqual(msg.from_agent, "agent-a")
        self.assertEqual(msg.to_agent, "agent-b")
        self.assertEqual(msg.task_id, "task-123")
        self.assertEqual(msg.message_type, "task_response")
        self.assertEqual(msg.payload, {"result": "success"})
        self.assertEqual(msg.priority, 7)
        self.assertEqual(msg.metadata, {"attempt": 1})

    def test_to_dict(self):
        """Test converting message to dictionary."""
        msg = AgentMessage(
            from_agent="agent-a",
            to_agent="agent-b",
            task_id="task-123",
        )

        data = msg.to_dict()

        self.assertIsInstance(data, dict)
        self.assertEqual(data["from_agent"], "agent-a")
        self.assertEqual(data["to_agent"], "agent-b")
        self.assertEqual(data["task_id"], "task-123")
        self.assertIn("id", data)
        self.assertIn("timestamp", data)
        self.assertIn("correlation_id", data)

    def test_to_json(self):
        """Test converting message to JSON string."""
        msg = AgentMessage(
            from_agent="agent-a",
            to_agent="agent-b",
            payload={"key": "value"},
        )

        json_str = msg.to_json()

        self.assertIsInstance(json_str, str)
        parsed = json.loads(json_str)
        self.assertEqual(parsed["from_agent"], "agent-a")
        self.assertEqual(parsed["payload"], {"key": "value"})

    def test_from_dict(self):
        """Test creating message from dictionary."""
        data = {
            "id": "msg-123",
            "from_agent": "agent-a",
            "to_agent": "agent-b",
            "task_id": "task-456",
            "message_type": "event",
            "payload": {"event": "completed"},
            "correlation_id": "corr-789",
            "timestamp": "2026-01-08T12:00:00",
            "priority": 9,
        }

        msg = AgentMessage.from_dict(data)

        self.assertEqual(msg.id, "msg-123")
        self.assertEqual(msg.from_agent, "agent-a")
        self.assertEqual(msg.to_agent, "agent-b")
        self.assertEqual(msg.task_id, "task-456")
        self.assertEqual(msg.message_type, "event")
        self.assertEqual(msg.payload, {"event": "completed"})
        # correlation_id was supplied in the input but previously never checked.
        self.assertEqual(msg.correlation_id, "corr-789")
        self.assertEqual(msg.priority, 9)

    def test_from_json(self):
        """Test creating message from JSON string."""
        json_str = json.dumps({
            "from_agent": "agent-a",
            "to_agent": "agent-b",
            "payload": {"data": 123},
        })

        msg = AgentMessage.from_json(json_str)

        self.assertEqual(msg.from_agent, "agent-a")
        self.assertEqual(msg.to_agent, "agent-b")
        self.assertEqual(msg.payload, {"data": 123})

    def test_from_dict_ignores_unknown_fields(self):
        """Test that unknown fields are ignored when creating from dict."""
        data = {
            "from_agent": "agent-a",
            "unknown_field": "should be ignored",
            "another_unknown": 123,
        }

        msg = AgentMessage.from_dict(data)

        self.assertEqual(msg.from_agent, "agent-a")
        # Every unknown key must be dropped, not just the first one.
        self.assertFalse(hasattr(msg, "unknown_field"))
        self.assertFalse(hasattr(msg, "another_unknown"))

    def test_round_trip_serialization(self):
        """Test that message survives JSON round-trip."""
        original = AgentMessage(
            from_agent="agent-a",
            to_agent="agent-b",
            task_id="task-123",
            payload={"nested": {"data": [1, 2, 3]}},
            metadata={"key": "value"},
        )

        json_str = original.to_json()
        restored = AgentMessage.from_json(json_str)

        # Identity/routing fields must survive too, otherwise a restored
        # message could not be correlated with the original request.
        self.assertEqual(original.id, restored.id)
        self.assertEqual(original.correlation_id, restored.correlation_id)
        self.assertEqual(original.message_type, restored.message_type)
        self.assertEqual(original.priority, restored.priority)
        self.assertEqual(original.from_agent, restored.from_agent)
        self.assertEqual(original.to_agent, restored.to_agent)
        self.assertEqual(original.task_id, restored.task_id)
        self.assertEqual(original.payload, restored.payload)
        self.assertEqual(original.metadata, restored.metadata)

class TestMessageBusConfig(unittest.TestCase):
    """Test MessageBusConfig dataclass."""

    def test_default_config(self):
        """Test default configuration values."""
        config = MessageBusConfig()

        self.assertEqual(config.host, "localhost")
        self.assertEqual(config.port, 5672)
        self.assertEqual(config.username, "coditect")
        self.assertEqual(config.password, "coditect_dev_2026")
        self.assertEqual(config.vhost, "coditect")
        self.assertEqual(config.heartbeat, 60)
        self.assertFalse(config.use_ssl)

    def test_custom_config(self):
        """Test custom configuration values."""
        config = MessageBusConfig(
            host="rabbitmq.example.com",
            port=5673,
            username="custom_user",
            password="custom_pass",
            vhost="custom_vhost",
            use_ssl=True,
        )

        self.assertEqual(config.host, "rabbitmq.example.com")
        self.assertEqual(config.port, 5673)
        self.assertEqual(config.username, "custom_user")
        self.assertEqual(config.password, "custom_pass")
        self.assertEqual(config.vhost, "custom_vhost")
        self.assertTrue(config.use_ssl)

    def test_url_property(self):
        """Test URL generation."""
        config = MessageBusConfig(
            host="localhost",
            port=5672,
            username="user",
            password="pass",
            vhost="test",
        )

        self.assertEqual(config.url, "amqp://user:pass@localhost:5672/test")

    def test_url_with_ssl(self):
        """Test URL generation with SSL."""
        config = MessageBusConfig(
            host="secure.host.com",
            use_ssl=True,
            username="user",
            password="pass",
            vhost="prod",
        )

        self.assertTrue(config.url.startswith("amqps://"))
        # The scheme alone is not enough; the configured host must be used.
        self.assertIn("secure.host.com", config.url)

    def test_from_env(self):
        """Test creating config from environment variables."""
        env_vars = {
            "RABBITMQ_HOST": "env-host",
            "RABBITMQ_PORT": "5673",
            "RABBITMQ_USER": "env-user",
            "RABBITMQ_PASS": "env-pass",
            "RABBITMQ_VHOST": "env-vhost",
            "RABBITMQ_SSL": "true",
        }

        # patch.dict restores os.environ when the block exits.
        with patch.dict(os.environ, env_vars, clear=False):
            config = MessageBusConfig.from_env()

        self.assertEqual(config.host, "env-host")
        self.assertEqual(config.port, 5673)
        self.assertEqual(config.username, "env-user")
        self.assertEqual(config.password, "env-pass")
        self.assertEqual(config.vhost, "env-vhost")
        self.assertTrue(config.use_ssl)

    def test_from_file(self):
        """Test creating config from JSON file."""
        config_data = {
            "host": "file-host",
            "port": 5674,
            "username": "file-user",
            "password": "file-pass",
            "vhost": "file-vhost",
        }

        # A temporary directory is removed on exit even if writing the file
        # fails; NamedTemporaryFile(delete=False) + a later try/finally
        # leaked the file when json.dump raised.
        with tempfile.TemporaryDirectory() as tmp_dir:
            config_path = Path(tmp_dir) / "message_bus_config.json"
            config_path.write_text(json.dumps(config_data), encoding="utf-8")

            config = MessageBusConfig.from_file(str(config_path))

        self.assertEqual(config.host, "file-host")
        self.assertEqual(config.port, 5674)
        self.assertEqual(config.username, "file-user")
        self.assertEqual(config.password, "file-pass")
        self.assertEqual(config.vhost, "file-vhost")

class TestLocalMessageQueue(unittest.TestCase):
    """Test LocalMessageQueue for development mode."""

    # Upper bound (seconds) to wait for the queue's processing loop (started
    # by start_processing) to deliver messages. Polling returns as soon as
    # the expected deliveries are seen, so passing tests stay fast while slow
    # CI machines get more headroom than the old fixed 0.3 s sleep.
    DELIVERY_TIMEOUT = 2.0

    @staticmethod
    def _wait_until(predicate, timeout=DELIVERY_TIMEOUT, interval=0.01):
        """Poll ``predicate`` until it returns truthy or ``timeout`` elapses.

        Returns the final value of ``predicate()``, so callers can either
        assert on it or follow up with a more descriptive assertion.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            time.sleep(interval)
        return bool(predicate())

    def setUp(self):
        """Set up test fixtures."""
        self.queue = LocalMessageQueue()

    def tearDown(self):
        """Clean up."""
        self.queue.stop_processing()

    def test_get_or_create_queue(self):
        """Test queue creation on demand."""
        q1 = self.queue.get_queue("test-queue")
        q2 = self.queue.get_queue("test-queue")

        self.assertIs(q1, q2)  # Same queue returned

    def test_publish_message(self):
        """Test publishing a message to a queue."""
        msg = AgentMessage(from_agent="test", payload={"data": "test"})

        self.queue.publish("test-queue", msg)

        q = self.queue.get_queue("test-queue")
        self.assertEqual(q.qsize(), 1)

    def test_subscribe_and_receive(self):
        """Test subscribing and receiving messages."""
        received_messages: List[AgentMessage] = []

        def callback(msg: AgentMessage):
            received_messages.append(msg)

        self.queue.subscribe("test-queue", callback)
        self.queue.start_processing()

        # Publish message
        msg = AgentMessage(from_agent="sender", payload={"test": True})
        self.queue.publish("test-queue", msg)

        # Wait for processing
        self._wait_until(lambda: len(received_messages) >= 1)

        self.assertEqual(len(received_messages), 1)
        self.assertEqual(received_messages[0].from_agent, "sender")

    def test_multiple_subscribers(self):
        """Test multiple subscribers to same queue."""
        received_1: List[AgentMessage] = []
        received_2: List[AgentMessage] = []

        self.queue.subscribe("shared-queue", lambda m: received_1.append(m))
        self.queue.subscribe("shared-queue", lambda m: received_2.append(m))
        self.queue.start_processing()

        msg = AgentMessage(payload={"shared": True})
        self.queue.publish("shared-queue", msg)

        self._wait_until(lambda: received_1 and received_2)

        self.assertEqual(len(received_1), 1)
        self.assertEqual(len(received_2), 1)

    def test_multiple_queues(self):
        """Test messages go to correct queues."""
        queue_a_messages: List[AgentMessage] = []
        queue_b_messages: List[AgentMessage] = []

        self.queue.subscribe("queue-a", lambda m: queue_a_messages.append(m))
        self.queue.subscribe("queue-b", lambda m: queue_b_messages.append(m))
        self.queue.start_processing()

        self.queue.publish("queue-a", AgentMessage(payload={"queue": "a"}))
        self.queue.publish("queue-b", AgentMessage(payload={"queue": "b"}))
        self.queue.publish("queue-a", AgentMessage(payload={"queue": "a2"}))

        self._wait_until(
            lambda: len(queue_a_messages) >= 2 and len(queue_b_messages) >= 1
        )

        self.assertEqual(len(queue_a_messages), 2)
        self.assertEqual(len(queue_b_messages), 1)

class TestMessageBusLocalMode(unittest.TestCase):
    """Test MessageBus in local fallback mode (no RabbitMQ).

    Every test that connects disconnects in a ``finally`` so a failing
    assertion does not leave the bus connected.

    NOTE(review): these tests assume connect(use_local_fallback=True) ends up
    in local mode; if a RabbitMQ broker is reachable with the default config
    the bus may connect to it instead — confirm against MessageBus.connect.
    """

    def test_connect_local_mode(self):
        """Test connecting in local mode."""
        async def run_test():
            bus = MessageBus()
            connected = await bus.connect(use_local_fallback=True)
            try:
                self.assertTrue(connected)
                self.assertTrue(bus.is_connected)
                self.assertTrue(bus.is_local_mode)
            finally:
                await bus.disconnect()

        asyncio.run(run_test())

    def test_disconnect(self):
        """Test disconnecting from bus."""
        async def run_test():
            bus = MessageBus()
            await bus.connect(use_local_fallback=True)
            try:
                self.assertTrue(bus.is_connected)
            finally:
                await bus.disconnect()

            self.assertFalse(bus.is_connected)

        asyncio.run(run_test())

    def test_send_task_requires_connection(self):
        """Test that sending requires connection."""
        async def run_test():
            bus = MessageBus()

            with self.assertRaises(RuntimeError) as ctx:
                await bus.send_task(
                    from_agent="a",
                    to_agent="b",
                    task_id="t1",
                    payload={},
                )

            self.assertIn("Not connected", str(ctx.exception))

        asyncio.run(run_test())

    def test_send_task_local_mode(self):
        """Test sending a task in local mode."""
        async def run_test():
            bus = MessageBus()
            await bus.connect(use_local_fallback=True)
            try:
                correlation_id = await bus.send_task(
                    from_agent="orchestrator",
                    to_agent="test-agent",
                    task_id="task-123",
                    payload={"action": "test"},
                    priority=MessagePriority.HIGH.value,
                )

                self.assertIsNotNone(correlation_id)
                self.assertEqual(len(correlation_id), 36)  # UUID format
            finally:
                await bus.disconnect()

        asyncio.run(run_test())

    def test_broadcast_event_local_mode(self):
        """Test broadcasting an event in local mode."""
        async def run_test():
            bus = MessageBus()
            await bus.connect(use_local_fallback=True)
            try:
                # Should not raise
                await bus.broadcast_event(
                    from_agent="system",
                    event_type="status_update",
                    payload={"status": "running"},
                )
            finally:
                await bus.disconnect()

        asyncio.run(run_test())

    def test_publish_event_local_mode(self):
        """Test publishing a topic event in local mode."""
        async def run_test():
            bus = MessageBus()
            await bus.connect(use_local_fallback=True)
            try:
                # Should not raise
                await bus.publish_event(
                    from_agent="monitor",
                    event_type="metric",
                    payload={"cpu": 50},
                    routing_key="metrics.cpu",
                )
            finally:
                await bus.disconnect()

        asyncio.run(run_test())

class TestCreateTaskMessage(unittest.TestCase):
    """Exercise the create_task_message() convenience constructor."""

    def test_create_task_message(self):
        """Helper fills routing fields, request type and reply queue."""
        message = create_task_message(
            from_agent="orchestrator",
            to_agent="worker",
            task_id="task-123",
            payload={"command": "process"},
        )

        # Checked in a fixed order; the attribute name doubles as the
        # failure message so a mismatch is easy to locate.
        expected = {
            "from_agent": "orchestrator",
            "to_agent": "worker",
            "task_id": "task-123",
            "message_type": MessageType.TASK_REQUEST.value,
            "payload": {"command": "process"},
            "reply_to": "agent.orchestrator.responses",
        }
        for attribute, wanted in expected.items():
            self.assertEqual(getattr(message, attribute), wanted, attribute)

    def test_create_task_message_with_priority(self):
        """An explicit priority is carried through to the message."""
        urgent = create_task_message(
            from_agent="orchestrator",
            to_agent="worker",
            task_id="urgent-123",
            payload={},
            priority=MessagePriority.CRITICAL.value,
        )

        self.assertEqual(urgent.priority, 10)

class TestMessageTypeEnum(unittest.TestCase):
    """Pin the wire values of every MessageType member."""

    def test_message_types(self):
        """Each enum member serializes to its expected string value."""
        expected_values = (
            (MessageType.TASK_REQUEST, "task_request"),
            (MessageType.TASK_RESPONSE, "task_response"),
            (MessageType.EVENT, "event"),
            (MessageType.QUERY, "query"),
            (MessageType.HEARTBEAT, "heartbeat"),
            (MessageType.BROADCAST, "broadcast"),
            (MessageType.ERROR, "error"),
        )
        for member, wire_value in expected_values:
            self.assertEqual(member.value, wire_value)

class TestMessagePriorityEnum(unittest.TestCase):
    """Pin ordering and key numeric values of MessagePriority."""

    def test_priority_order(self):
        """Each level is strictly lower than the next one up."""
        ascending = [
            MessagePriority.LOWEST,
            MessagePriority.LOW,
            MessagePriority.NORMAL,
            MessagePriority.HIGH,
            MessagePriority.URGENT,
            MessagePriority.CRITICAL,
        ]
        for lower, higher in zip(ascending, ascending[1:]):
            self.assertLess(lower.value, higher.value)

    def test_priority_values(self):
        """The anchor levels have fixed numeric values."""
        anchors = (
            (MessagePriority.LOWEST, 1),
            (MessagePriority.NORMAL, 5),
            (MessagePriority.CRITICAL, 10),
        )
        for member, number in anchors:
            self.assertEqual(member.value, number)

class TestMessageBusQueueStats(unittest.TestCase):
    """Test queue statistics functionality."""

    def test_get_queue_stats_local_mode(self):
        """Test getting queue stats in local mode."""
        async def run_test():
            bus = MessageBus()
            await bus.connect(use_local_fallback=True)
            # Disconnect in finally so a failing assertion does not leave
            # the bus connected.
            try:
                # Publish some messages
                await bus.send_task(
                    from_agent="a",
                    to_agent="b",
                    task_id="t1",
                    payload={},
                )

                stats = await bus.get_queue_stats()

                self.assertIsInstance(stats, dict)
            finally:
                await bus.disconnect()

        asyncio.run(run_test())

class TestMessageBusSendResponse(unittest.TestCase):
    """Test response sending functionality.

    Each test disconnects in a ``finally`` so a failure in send_response
    does not leave the bus connected.
    """

    def test_send_response_local_mode(self):
        """Test sending a response in local mode."""
        async def run_test():
            bus = MessageBus()
            await bus.connect(use_local_fallback=True)
            try:
                original = AgentMessage(
                    from_agent="requester",
                    to_agent="responder",
                    task_id="task-123",
                    reply_to="agent.requester.responses",
                )

                # Should not raise
                await bus.send_response(
                    original_message=original,
                    payload={"result": "done"},
                    success=True,
                )
            finally:
                await bus.disconnect()

        asyncio.run(run_test())

    def test_send_error_response(self):
        """Test sending an error response."""
        async def run_test():
            bus = MessageBus()
            await bus.connect(use_local_fallback=True)
            try:
                original = AgentMessage(
                    from_agent="requester",
                    to_agent="responder",
                    task_id="task-456",
                    reply_to="agent.requester.responses",
                )

                # Should not raise
                await bus.send_response(
                    original_message=original,
                    payload={"error": "Something failed"},
                    success=False,
                )
            finally:
                await bus.disconnect()

        asyncio.run(run_test())

def run_tests():
    """Run every test class in this module and print a summary banner.

    Returns:
        The ``unittest.TestResult`` produced by a verbose text runner, so the
        caller can decide the process exit status.
    """
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()

    # Explicit list of test classes; keep in sync when adding new ones.
    test_classes = (
        TestAgentMessage,
        TestMessageBusConfig,
        TestLocalMessageQueue,
        TestMessageBusLocalMode,
        TestCreateTaskMessage,
        TestMessageTypeEnum,
        TestMessagePriorityEnum,
        TestMessageBusQueueStats,
        TestMessageBusSendResponse,
    )
    for case_cls in test_classes:
        suite.addTests(loader.loadTestsFromTestCase(case_cls))

    result = unittest.TextTestRunner(verbosity=2).run(suite)

    rule = "=" * 70
    print("\n" + rule)
    print("TEST SUMMARY")
    print(rule)
    summary_rows = (
        ("Tests Run", result.testsRun),
        ("Failures", len(result.failures)),
        ("Errors", len(result.errors)),
        ("Skipped", len(result.skipped)),
    )
    for label, count in summary_rows:
        print(f"{label}: {count}")

    verdict = "PASSED" if result.wasSuccessful() else "FAILED"
    print(f"\nOverall: {verdict}")
    print(rule)

    return result

# Script entry point: run the suite and exit non-zero on any failure so
# shells and CI can detect it.
if __name__ == "__main__":
    result = run_tests()
    sys.exit(0 if result.wasSuccessful() else 1)