#!/usr/bin/env python3
"""
Integration tests for the Inter-Session Message Bus (ADR-160).

Tests cover:
  - H.13.2: MessageBus abstraction (SQLite transport, schema, retry)
  - H.13.4: Session registration & heartbeat
  - H.13.5: File conflict detection (advisory locks)
  - H.13.6: Task broadcasting & routing (claim/release, cross-LLM status)

Author: Claude (Opus 4.6)
Created: 2026-02-06
Track: H.13.7.4
"""

import contextlib
import os
import shutil
import socket
import sqlite3
import sys
import tempfile
import threading
import time
import unittest
from pathlib import Path

# Ensure the project root is importable: the root is taken to be three
# directory levels above this file, and it is put first on sys.path so the
# `scripts.core...` import below resolves.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from scripts.core.session_message_bus import ( CHANNEL_TTL, FileLock, SessionInfo, SessionMessage, SQLiteSessionMessageBus, TaskClaim, get_session_message_bus, )

class TestSessionMessageBus(unittest.TestCase):
    """Base test class with temp database setup.

    Every test gets its own temporary directory (``self.tmpdir``), a database
    path inside it (``self.db_path``) and a ``SQLiteSessionMessageBus`` opened
    on that path with session id ``"test-session"`` (``self.bus``).
    """

    def setUp(self):
        """Create the temp directory and open the bus under test."""
        self.tmpdir = tempfile.mkdtemp()
        # Register removal right away: unittest skips tearDown when setUp
        # raises, but still runs cleanups, so the directory cannot leak if
        # constructing the bus below fails.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.db_path = Path(self.tmpdir) / "messaging.db"
        self.bus = SQLiteSessionMessageBus(
            db_path=self.db_path, session_id="test-session"
        )

    def tearDown(self):
        """Close the bus; the temp directory is removed by the cleanup
        registered in setUp, which runs after tearDown even if close() raises.
        """
        self.bus.close()

class TestSchema(TestSessionMessageBus):
    """H.13.2: Schema and database initialization."""

    def _fetch_all(self, sql):
        """Run *sql* on a separate connection to the test database.

        The connection is independent of ``self.bus`` and is always closed,
        even when the query raises.

        Returns:
            The list of rows produced by ``fetchall()``.
        """
        with contextlib.closing(sqlite3.connect(str(self.db_path))) as conn:
            return conn.execute(sql).fetchall()

    def test_database_created(self):
        """Constructing the bus creates the database file on disk."""
        self.assertTrue(self.db_path.exists())

    def test_wal_mode(self):
        """The database reports the "wal" journal mode."""
        mode = self._fetch_all("PRAGMA journal_mode")[0][0]
        self.assertEqual(mode, "wal")

    def test_tables_exist(self):
        """All four bus tables are present in sqlite_master."""
        tables = {
            row[0]
            for row in self._fetch_all(
                "SELECT name FROM sqlite_master WHERE type='table'"
            )
        }
        expected_tables = (
            "session_registry",
            "inter_session_messages",
            "file_locks",
            "task_claims",
        )
        for expected in expected_tables:
            # subTest reports every missing table, not only the first one.
            with self.subTest(table=expected):
                self.assertIn(expected, tables)

class TestSessionRegistration(TestSessionMessageBus):
    """H.13.4: Session registration and heartbeat."""

    def test_register_session(self):
        """A registered session is listed with its id and vendor."""
        self.bus.register_session(
            "claude-100", "claude", "opus-4.6", project_id="PILOT", pid=100
        )
        sessions = self.bus.list_sessions()
        self.assertEqual(len(sessions), 1)
        self.assertEqual(sessions[0].session_id, "claude-100")
        self.assertEqual(sessions[0].llm_vendor, "claude")

    def test_register_multiple_vendors(self):
        """Sessions from different vendors are all listed."""
        self.bus.register_session("claude-1", "claude", pid=1)
        self.bus.register_session("codex-2", "codex", pid=2)
        self.bus.register_session("kimi-3", "kimi", pid=3)
        sessions = self.bus.list_sessions()
        vendors = {s.llm_vendor for s in sessions}
        self.assertEqual(vendors, {"claude", "codex", "kimi"})

    def test_heartbeat(self):
        """heartbeat() changes the session's heartbeat_at value."""
        self.bus.register_session("s1", "claude", pid=1)
        first = self.bus.list_sessions()[0].heartbeat_at
        time.sleep(1.1)  # Need >1s for second-precision timestamps
        self.bus.heartbeat("s1")
        second = self.bus.list_sessions()[0].heartbeat_at
        self.assertNotEqual(first, second)

    def test_unregister(self):
        """An unregistered session is gone even from the inactive listing."""
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.unregister_session("s1")
        sessions = self.bus.list_sessions(active_only=False)
        self.assertEqual(len(sessions), 0)

    def test_stale_cleanup(self):
        """With a zero stale timeout, an old session drops out of the
        active listing.
        """
        bus = SQLiteSessionMessageBus(
            db_path=self.db_path, stale_timeout_seconds=0
        )
        # try/finally so the extra bus is closed even when an assertion fails.
        try:
            bus.register_session("stale-1", "claude", pid=999)
            # Must exceed stale_timeout (0s) + timestamp granularity
            time.sleep(1.1)
            sessions = bus.list_sessions(active_only=True)
            stale = [s for s in sessions if s.session_id == "stale-1"]
            self.assertEqual(len(stale), 0)
        finally:
            bus.close()

    def test_heartbeat_thread(self):
        """The heartbeat thread runs after start and exits after stop."""
        self.bus.register_session("ht-1", "claude", pid=1)
        t = self.bus.start_heartbeat_thread(interval=0.1, session_id="ht-1")
        try:
            self.assertTrue(t.is_alive())
            time.sleep(0.3)  # let a few 0.1s intervals elapse
        finally:
            # Always stop the background thread, even if the liveness
            # assertion above failed, so it cannot outlive the test.
            self.bus.stop_heartbeat_thread()
        # Wait for the thread to finish rather than sleeping a fixed 0.2s:
        # returns as soon as it exits and tolerates a slow machine.
        # NOTE(review): assumes start_heartbeat_thread returns a
        # threading.Thread (it already exposes is_alive()) - confirm.
        t.join(timeout=2.0)
        self.assertFalse(t.is_alive())

class TestMessaging(TestSessionMessageBus):
    """H.13.2: Pub/sub messaging."""

    def test_publish_and_poll(self):
        """A published message comes back from poll() with its payload."""
        msg_id = self.bus.publish("state", {"task": "H.13"})
        self.assertGreater(msg_id, 0)
        messages = self.bus.poll("state", since_id=0)
        self.assertEqual(len(messages), 1)
        self.assertEqual(messages[0].payload["task"], "H.13")

    def test_poll_respects_since_id(self):
        """poll(since_id=X) returns only messages published after id X."""
        id1 = self.bus.publish("ch", {"n": 1})
        id2 = self.bus.publish("ch", {"n": 2})
        # The since_id filter below only makes sense if ids grow with
        # publication order, so pin that explicitly.
        self.assertGreater(id2, id1)
        msgs = self.bus.poll("ch", since_id=id1)
        self.assertEqual(len(msgs), 1)
        self.assertEqual(msgs[0].payload["n"], 2)

    def test_channel_isolation(self):
        """Messages on one channel are not returned when polling another."""
        self.bus.publish("alpha", {"x": 1})
        self.bus.publish("beta", {"y": 2})
        alpha_msgs = self.bus.poll("alpha")
        beta_msgs = self.bus.poll("beta")
        self.assertEqual(len(alpha_msgs), 1)
        self.assertEqual(len(beta_msgs), 1)

    def test_ttl_expiry(self):
        """A message with ttl_seconds=0 is gone after expiry cleanup."""
        self.bus.publish("short", {"data": 1}, ttl_seconds=0)
        time.sleep(1.1)  # Must exceed TTL + timestamp granularity (1s)
        self.bus._cleanup_expired()
        msgs = self.bus.poll("short")
        self.assertEqual(len(msgs), 0)

    def test_channel_ttl_defaults(self):
        """CHANNEL_TTL defines a 600-second TTL for "task_broadcast"."""
        self.assertIn("task_broadcast", CHANNEL_TTL)
        self.assertEqual(CHANNEL_TTL["task_broadcast"], 600)

class TestFileLocks(TestSessionMessageBus):
    """H.13.5: Advisory file locks."""

    def test_lock_and_unlock(self):
        """A session can lock a file and then unlock it again."""
        self.bus.register_session("s1", "claude", pid=1)
        self.assertTrue(self.bus.lock_file("test.py", session_id="s1"))
        self.assertTrue(self.bus.unlock_file("test.py", session_id="s1"))

    def test_conflict_detection(self):
        """A second session is refused a file already locked by another."""
        for sid, vendor, pid in (("s1", "claude", 1), ("s2", "codex", 2)):
            self.bus.register_session(sid, vendor, pid=pid)
        self.bus.lock_file("shared.py", session_id="s1")
        second_attempt = self.bus.lock_file("shared.py", session_id="s2")
        self.assertFalse(second_attempt)

    def test_same_session_relock(self):
        """Locking a file twice from the same session succeeds."""
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.lock_file("file.py", session_id="s1")
        self.assertTrue(self.bus.lock_file("file.py", session_id="s1"))

    def test_stale_lock_takeover(self):
        """After the holder unregisters, another session can lock the file."""
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.lock_file("file.py", session_id="s1")
        self.bus.unregister_session("s1")
        self.bus.register_session("s2", "codex", pid=2)
        self.assertTrue(self.bus.lock_file("file.py", session_id="s2"))

    def test_get_file_locks(self):
        """get_file_locks() reports every file locked by the session."""
        self.bus.register_session("s1", "claude", pid=1)
        for name in ("a.py", "b.py"):
            self.bus.lock_file(name, session_id="s1")
        held = self.bus.get_file_locks()
        self.assertEqual(len(held), 2)
        self.assertEqual({lock.file_path for lock in held}, {"a.py", "b.py"})

    def test_unregister_cleans_locks(self):
        """Unregistering a session leaves no file locks behind."""
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.lock_file("locked.py", session_id="s1")
        self.bus.unregister_session("s1")
        remaining = self.bus.get_file_locks()
        self.assertEqual(len(remaining), 0)

class TestTaskBroadcasting(TestSessionMessageBus):
    """H.13.6: Task broadcasting and routing."""

    def test_broadcast_task(self):
        """broadcast_task() puts one message on the "task_broadcast" channel."""
        message_id = self.bus.broadcast_task("H.13.6.1", "started")
        self.assertGreater(message_id, 0)
        received = self.bus.poll("task_broadcast")
        self.assertEqual(len(received), 1)
        only = received[0]
        self.assertEqual(only.payload["task_id"], "H.13.6.1")
        self.assertEqual(only.payload["action"], "started")

    def test_broadcast_with_details(self):
        """The optional details dict is carried in the payload."""
        self.bus.broadcast_task("A.1", "completed", {"outcome": "success"})
        received = self.bus.poll("task_broadcast")
        details = received[0].payload["details"]
        self.assertEqual(details["outcome"], "success")

    def test_claim_task(self):
        """A registered session can claim an unclaimed task."""
        self.bus.register_session("s1", "claude", pid=1)
        self.assertTrue(self.bus.claim_task("H.8.1", session_id="s1"))

    def test_claim_conflict(self):
        """A task already claimed by one session is refused to another."""
        for sid, vendor, pid in (("s1", "claude", 1), ("s2", "codex", 2)):
            self.bus.register_session(sid, vendor, pid=pid)
        self.bus.claim_task("H.8.1", session_id="s1")
        second_attempt = self.bus.claim_task("H.8.1", session_id="s2")
        self.assertFalse(second_attempt)

    def test_release_task(self):
        """Releasing a claim succeeds and removes it from the claim list."""
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.claim_task("task-1", session_id="s1")
        self.assertTrue(self.bus.release_task("task-1", session_id="s1"))
        remaining = self.bus.get_task_claims()
        self.assertEqual(len(remaining), 0)

    def test_stale_claim_takeover(self):
        """After the claimant unregisters, another session can claim the task."""
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.claim_task("task-1", session_id="s1")
        self.bus.unregister_session("s1")
        self.bus.register_session("s2", "codex", pid=2)
        self.assertTrue(self.bus.claim_task("task-1", session_id="s2"))

    def test_get_task_claims(self):
        """get_task_claims() reports every task claimed by the session."""
        self.bus.register_session("s1", "claude", pid=1)
        for task in ("t1", "t2"):
            self.bus.claim_task(task, session_id="s1")
        current = self.bus.get_task_claims()
        self.assertEqual(len(current), 2)
        self.assertEqual({claim.task_id for claim in current}, {"t1", "t2"})

    def test_unregister_cleans_claims(self):
        """Unregistering a session leaves none of its claims behind."""
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.claim_task("t1", session_id="s1")
        self.bus.unregister_session("s1")
        leftover = [
            claim
            for claim in self.bus.get_task_claims()
            if claim.session_id == "s1"
        ]
        self.assertEqual(len(leftover), 0)

class TestCrossLLMStatus(TestSessionMessageBus):
    """H.13.6.4: Cross-LLM status visibility."""

    def test_get_cross_llm_status(self):
        """The status report has one entry per registered session."""
        self.bus.register_session(
            "c1", "claude", "opus-4.6", project_id="P1", pid=1
        )
        self.bus.register_session("k1", "kimi", "k2", project_id="P1", pid=2)
        self.bus.update_session_task("H.13", session_id="c1")
        self.bus.update_session_task("F.4", session_id="k1")
        report = self.bus.get_cross_llm_status()
        self.assertEqual(len(report), 2)
        seen_vendors = {entry["llm_vendor"] for entry in report}
        self.assertEqual(seen_vendors, {"claude", "kimi"})

    def test_cross_status_includes_claims(self):
        """A session's claimed task shows up under "claimed_task"."""
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.claim_task("H.13.7", session_id="s1")
        report = self.bus.get_cross_llm_status()
        first_entry = report[0]
        self.assertEqual(first_entry["claimed_task"], "H.13.7")

    def test_update_session_task(self):
        """update_session_task() is reflected in the session listing."""
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.update_session_task("A.9.1", session_id="s1")
        listed = self.bus.list_sessions()
        self.assertEqual(listed[0].task_id, "A.9.1")

class TestStats(TestSessionMessageBus):
    """Stats and diagnostics."""

    def test_stats(self):
        """stats() counts sessions, messages, locks and claims, and reports
        a non-zero database size.
        """
        self.bus.register_session("s1", "claude", pid=1)
        self.bus.publish("ch", {"x": 1})
        self.bus.lock_file("f.py", session_id="s1")
        # Also publishes a broadcast
        self.bus.claim_task("t1", session_id="s1")

        snapshot = self.bus.stats()
        self.assertEqual(snapshot["active_sessions"], 1)
        # >= 1 (publish + claim broadcast)
        self.assertGreaterEqual(snapshot["pending_messages"], 1)
        for counter in ("file_locks", "task_claims"):
            self.assertEqual(snapshot[counter], 1)
        self.assertGreater(snapshot["db_size_bytes"], 0)

class TestConcurrency(TestSessionMessageBus):
    """Concurrent access correctness."""

    @staticmethod
    def _run_threads(threads):
        """Start every thread, then block until all of them have finished."""
        for t in threads:
            t.start()
        for t in threads:
            t.join()

    def test_concurrent_publish(self):
        """10 threads publishing 50 messages each should not lose any."""
        errors = []

        def publisher(n):
            # Each thread opens its own bus on the shared database file.
            try:
                b = SQLiteSessionMessageBus(
                    db_path=self.db_path, session_id=f"pub-{n}"
                )
                try:
                    for i in range(50):
                        b.publish("concurrent", {"n": n, "i": i})
                finally:
                    # Close the per-thread bus even when publish() raises.
                    b.close()
            except Exception as e:
                # Thread boundary: record the failure so the main thread can
                # report it instead of losing it.
                errors.append(str(e))

        self._run_threads(
            [threading.Thread(target=publisher, args=(i,)) for i in range(10)]
        )

        self.assertEqual(len(errors), 0, f"Errors: {errors}")
        msgs = self.bus.poll("concurrent", limit=1000)
        self.assertEqual(len(msgs), 500)

    def test_concurrent_claim_no_double_assign(self):
        """Only one session should win a claim under contention."""
        for i in range(10):
            self.bus.register_session(f"contender-{i}", "claude", pid=100 + i)

        results = []

        def claimer(sid):
            b = SQLiteSessionMessageBus(db_path=self.db_path, session_id=sid)
            try:
                won = b.claim_task("contested-task", session_id=sid)
                results.append((sid, won))
            finally:
                # Close the per-thread bus even when claim_task() raises.
                b.close()

        self._run_threads(
            [
                threading.Thread(target=claimer, args=(f"contender-{i}",))
                for i in range(10)
            ]
        )

        winners = [sid for sid, won in results if won]
        # At least one winner (previously computed but never checked) ...
        self.assertGreaterEqual(
            len(winners), 1, f"No contender won the claim: {results}"
        )
        # ... and claim state should be consistent: exactly one stored claim.
        claims = self.bus.get_task_claims()
        contested = [c for c in claims if c.task_id == "contested-task"]
        self.assertEqual(len(contested), 1)

class TestSessionContext(TestSessionMessageBus):
    """H.13.8: Machine identity, task tracking, active files, CWD, last_active_at."""

    def test_register_with_all_fields(self):
        """Registration populates machine_uuid, hostname, cwd, task_id."""
        self.bus.register_session(
            "full-1", "anthropic", "opus-4.6",
            project_id="PILOT", task_id="H.13.8",
            cwd="/tmp/test", machine_uuid="uuid-abc", hostname="dev-box",
        )
        sessions = self.bus.list_sessions()
        s = sessions[0]
        self.assertEqual(s.task_id, "H.13.8")
        self.assertEqual(s.cwd, "/tmp/test")
        self.assertEqual(s.machine_uuid, "uuid-abc")
        self.assertEqual(s.hostname, "dev-box")
        self.assertIsNotNone(s.last_active_at)

    def test_register_auto_detects_cwd_and_pid(self):
        """CWD and PID auto-detected when not provided."""
        self.bus.register_session("auto-1", "claude")
        sessions = self.bus.list_sessions()
        s = sessions[0]
        self.assertIsNotNone(s.cwd)
        self.assertIsNotNone(s.pid)
        self.assertEqual(s.pid, os.getpid())
        self.assertEqual(s.cwd, os.getcwd())

    def test_register_auto_detects_hostname(self):
        """Hostname auto-detected when not provided."""
        self.bus.register_session("host-1", "claude")
        sessions = self.bus.list_sessions()
        s = sessions[0]
        self.assertIsNotNone(s.hostname)
        self.assertEqual(s.hostname, socket.gethostname())

    def test_update_session_context_task(self):
        """update_session_context changes task_id and updates last_active_at."""
        self.bus.register_session("ctx-1", "claude", task_id="H.13.7")
        sessions = self.bus.list_sessions()
        initial_active = sessions[0].last_active_at

        # Brief pause between the two writes. Only "not earlier" is asserted
        # below, so the check holds whatever the timestamp granularity is.
        time.sleep(0.05)
        self.bus.update_session_context(session_id="ctx-1", task_id="H.13.8")

        sessions = self.bus.list_sessions()
        s = sessions[0]
        self.assertEqual(s.task_id, "H.13.8")
        self.assertGreaterEqual(s.last_active_at, initial_active)

    def test_update_session_context_active_files(self):
        """update_session_context sets active_files, read back as a list."""
        self.bus.register_session("ctx-2", "claude")
        self.bus.update_session_context(
            session_id="ctx-2",
            active_files=["paths.py", "session_message_bus.py"],
        )
        sessions = self.bus.list_sessions()
        s = sessions[0]
        self.assertEqual(s.active_files, ["paths.py", "session_message_bus.py"])

    def test_update_session_context_cwd(self):
        """update_session_context changes cwd."""
        self.bus.register_session("ctx-3", "codex", cwd="/old/path")
        self.bus.update_session_context(session_id="ctx-3", cwd="/new/path")
        sessions = self.bus.list_sessions()
        self.assertEqual(sessions[0].cwd, "/new/path")

    def test_update_session_context_project(self):
        """update_session_context changes project_id."""
        self.bus.register_session("ctx-4", "gemini", project_id="OLD")
        self.bus.update_session_context(session_id="ctx-4", project_id="NEW")
        sessions = self.bus.list_sessions()
        self.assertEqual(sessions[0].project_id, "NEW")

    def test_update_session_context_partial(self):
        """update_session_context only changes specified fields."""
        self.bus.register_session(
            "ctx-5", "claude", task_id="A.1", cwd="/original", project_id="PILOT"
        )
        self.bus.update_session_context(session_id="ctx-5", task_id="A.2")
        sessions = self.bus.list_sessions()
        s = sessions[0]
        self.assertEqual(s.task_id, "A.2")
        self.assertEqual(s.cwd, "/original")
        self.assertEqual(s.project_id, "PILOT")

    def test_cross_llm_status_includes_new_fields(self):
        """get_cross_llm_status returns machine_uuid, hostname, cwd, last_active_at."""
        self.bus.register_session(
            "status-1", "anthropic", "opus-4.6",
            project_id="PILOT", task_id="H.13.8",
            cwd="/work", machine_uuid="m-123", hostname="laptop",
        )
        self.bus.update_session_context(
            session_id="status-1",
            active_files=["bus.py"],
        )
        status = self.bus.get_cross_llm_status()
        self.assertEqual(len(status), 1)
        s = status[0]
        self.assertEqual(s["machine_uuid"], "m-123")
        self.assertEqual(s["hostname"], "laptop")
        self.assertEqual(s["cwd"], "/work")
        self.assertEqual(s["task_id"], "H.13.8")
        self.assertEqual(s["active_files"], ["bus.py"])
        self.assertIsNotNone(s["last_active_at"])

    def test_update_session_task_sets_last_active(self):
        """update_session_task also updates last_active_at."""
        self.bus.register_session("task-1", "claude")
        sessions = self.bus.list_sessions()
        initial = sessions[0].last_active_at

        # Brief pause between the two writes; only "not earlier" is asserted.
        time.sleep(0.05)
        self.bus.update_session_task("J.27.4", session_id="task-1")

        sessions = self.bus.list_sessions()
        self.assertEqual(sessions[0].task_id, "J.27.4")
        self.assertGreaterEqual(sessions[0].last_active_at, initial)

    def test_schema_migration_existing_db(self):
        """Existing database without new columns gets migrated."""
        # Create a db with OLD schema (no cwd, machine_uuid, hostname,
        # last_active_at). closing() guarantees the raw connection is
        # released even if one of the CREATE TABLE statements fails.
        old_db = Path(self.tmpdir) / "old.db"
        with contextlib.closing(sqlite3.connect(str(old_db))) as conn:
            conn.execute("""
                CREATE TABLE session_registry (
                    session_id TEXT PRIMARY KEY,
                    llm_vendor TEXT NOT NULL,
                    llm_model TEXT,
                    tty TEXT,
                    pid INTEGER,
                    project_id TEXT,
                    task_id TEXT,
                    active_files TEXT,
                    heartbeat_at TEXT NOT NULL,
                    registered_at TEXT NOT NULL DEFAULT (datetime('now')),
                    status TEXT DEFAULT 'active'
                )
            """)
            conn.execute("""
                CREATE TABLE inter_session_messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    sender_id TEXT NOT NULL,
                    channel TEXT NOT NULL,
                    payload TEXT NOT NULL,
                    created_at TEXT NOT NULL DEFAULT (datetime('now')),
                    ttl_seconds INTEGER DEFAULT 300
                )
            """)
            conn.execute("""
                CREATE TABLE file_locks (
                    file_path TEXT PRIMARY KEY,
                    session_id TEXT NOT NULL,
                    lock_type TEXT DEFAULT 'advisory',
                    locked_at TEXT NOT NULL DEFAULT (datetime('now'))
                )
            """)
            conn.execute("""
                CREATE TABLE task_claims (
                    task_id TEXT PRIMARY KEY,
                    session_id TEXT NOT NULL,
                    claimed_at TEXT NOT NULL DEFAULT (datetime('now')),
                    status TEXT DEFAULT 'claimed'
                )
            """)
            conn.commit()

        # Open with new code - migration should add missing columns
        bus2 = SQLiteSessionMessageBus(db_path=old_db)
        # try/finally so the second bus is closed even when an assertion fails.
        try:
            bus2.register_session(
                "migrated-1", "claude", cwd="/migrated",
                machine_uuid="m-999", hostname="migrated-host",
            )
            sessions = bus2.list_sessions()
            self.assertEqual(sessions[0].cwd, "/migrated")
            self.assertEqual(sessions[0].machine_uuid, "m-999")
            self.assertEqual(sessions[0].hostname, "migrated-host")
            self.assertIsNotNone(sessions[0].last_active_at)
        finally:
            bus2.close()

class TestFactory(unittest.TestCase):
    """Factory function tests."""

    def test_get_session_message_bus_custom_path(self):
        """get_session_message_bus() returns a bus for a custom db_path."""
        tmpdir = tempfile.mkdtemp()
        # Registered before anything can fail, so the directory is removed
        # even if the factory raises or the assertion below fails.
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
        db = Path(tmpdir) / "test.db"
        bus = get_session_message_bus(db_path=db)
        self.assertIsNotNone(bus)
        bus.close()

# Run the whole suite with verbose output when this file is executed directly.
if __name__ == "__main__":
    unittest.main(verbosity=2)