
#!/usr/bin/env python3
"""
Tests for Codanna MCP Security Wrapper

Run with: python3 -m pytest scripts/tests/test_codanna_mcp_wrapper.py -v
"""

import importlib.util
import json
import sys
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch

import pytest

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

# Import from hyphenated module name using importlib: "codanna-mcp-wrapper.py"
# is not a valid Python identifier, so a plain ``import`` statement cannot
# load it.
_WRAPPER_MODULE_NAME = "codanna_mcp_wrapper"
_WRAPPER_PATH = Path(__file__).parent.parent / "codanna-mcp-wrapper.py"

spec = importlib.util.spec_from_file_location(_WRAPPER_MODULE_NAME, _WRAPPER_PATH)
if spec is None or spec.loader is None:
    # spec_from_file_location returns None when no loader matches the file;
    # fail with a clear message instead of an AttributeError on spec.loader.
    raise ImportError(f"Cannot load {_WRAPPER_MODULE_NAME} from {_WRAPPER_PATH}")
codanna_mcp_wrapper = importlib.util.module_from_spec(spec)
# Register before executing (as in the importlib "import a source file
# directly" recipe) so code that resolves sys.modules[cls.__module__], such as
# dataclasses with string annotations, can find the module.
sys.modules[_WRAPPER_MODULE_NAME] = codanna_mcp_wrapper
spec.loader.exec_module(codanna_mcp_wrapper)

SecurityConfig = codanna_mcp_wrapper.SecurityConfig
TenantContext = codanna_mcp_wrapper.TenantContext
InputValidator = codanna_mcp_wrapper.InputValidator
RateLimiter = codanna_mcp_wrapper.RateLimiter
AuditLogger = codanna_mcp_wrapper.AuditLogger

# =============================================================================
# TenantContext Tests
# =============================================================================

class TestTenantContext:
    """Checks TenantContext's tenant-ID validation and on-disk index layout."""

    @staticmethod
    def _assert_rejected(bad_tenant_id):
        """Constructing a TenantContext from *bad_tenant_id* must raise ValueError."""
        with pytest.raises(ValueError, match="Invalid tenant_id format"):
            TenantContext(tenant_id=bad_tenant_id)

    def test_valid_tenant_id(self):
        """A plain alphanumeric tenant ID is stored unchanged."""
        tenant_id = "abc123"
        assert TenantContext(tenant_id=tenant_id).tenant_id == tenant_id

    def test_tenant_id_with_special_chars(self):
        """Underscores and hyphens are allowed in tenant IDs."""
        tenant_id = "tenant_123-abc"
        assert TenantContext(tenant_id=tenant_id).tenant_id == tenant_id

    def test_invalid_tenant_id_special_chars(self):
        """A slash is not an allowed tenant ID character."""
        self._assert_rejected("tenant/123")

    def test_invalid_tenant_id_too_long(self):
        """Tenant IDs longer than 64 characters are refused."""
        self._assert_rejected("a" * 65)

    def test_invalid_tenant_id_empty(self):
        """The empty string is not a valid tenant ID."""
        self._assert_rejected("")

    def test_index_path(self):
        """The index lives under <workspace>/.coditect/.codanna/<tenant_id>."""
        with tempfile.TemporaryDirectory() as scratch:
            root = Path(scratch)
            ctx = TenantContext(tenant_id="tenant123", workspace_root=root)
            assert ctx.index_path == root.joinpath(".coditect", ".codanna", "tenant123")

    def test_ensure_index_directory(self):
        """ensure_index_directory(0o700) creates an owner-only directory."""
        with tempfile.TemporaryDirectory() as scratch:
            ctx = TenantContext(tenant_id="tenant123", workspace_root=Path(scratch))
            created = ctx.ensure_index_directory(0o700)

            assert created.exists()
            assert created.is_dir()
            # Owner rwx only: no group or other permission bits may be set.
            assert created.stat().st_mode & 0o777 == 0o700

# =============================================================================
# InputValidator Tests
# =============================================================================

class TestInputValidator:
    """Tests for InputValidator.validate_message.

    Each test builds a JSON-RPC request and checks the ``(valid, error)`` pair
    returned by the validator. Accepted messages yield ``(True, None)``;
    rejected ones yield ``valid is False`` (several tests also check the
    error text).
    """

    @pytest.fixture
    def validator(self):
        return InputValidator(SecurityConfig())

    @staticmethod
    def _request(method, **params):
        """Return a JSON-RPC 2.0 request dict (id 1) for *method* with *params*."""
        return {"jsonrpc": "2.0", "id": 1, "method": method, "params": params}

    def test_valid_message(self, validator):
        """Valid messages should pass validation."""
        message = self._request("semantic_search", query="authentication handlers")
        valid, error = validator.validate_message(message)
        assert valid is True
        assert error is None

    def test_message_too_large(self, validator):
        """Messages over size limit should be rejected."""
        message = self._request("test", data="x" * 2_000_000)  # 2MB
        valid, error = validator.validate_message(message)
        assert valid is False
        assert "size limit" in error

    def test_query_too_long(self, validator):
        """Queries over length limit should be rejected."""
        message = self._request("semantic_search", query="x" * 20000)
        valid, error = validator.validate_message(message)
        assert valid is False
        assert "length limit" in error

    def test_query_with_json_injection(self, validator):
        """Queries with JSON control characters should be blocked."""
        message = self._request("semantic_search", query='"}}, {"malicious": true}')
        valid, error = validator.validate_message(message)
        assert valid is False
        assert "blocked pattern" in error

    def test_query_with_curly_braces(self, validator):
        """Queries with curly braces should be blocked."""
        message = self._request("semantic_search", query="find {something}")
        valid, _ = validator.validate_message(message)
        assert valid is False

    def test_path_traversal_blocked(self, validator):
        """Path traversal attempts should be blocked."""
        message = self._request("find_symbol", path="../../../etc/passwd")
        valid, error = validator.validate_message(message)
        assert valid is False
        assert "Path traversal" in error

    def test_absolute_path_outside_workspace(self, validator):
        """Absolute paths outside workspace should be blocked."""
        message = self._request("find_symbol", path="/etc/passwd")
        valid, _ = validator.validate_message(message)
        assert valid is False

    def test_home_directory_reference(self, validator):
        """Home directory references should be blocked."""
        message = self._request("find_symbol", path="~/.ssh/id_rsa")
        valid, _ = validator.validate_message(message)
        assert valid is False

    def test_null_byte_in_query(self, validator):
        """Null bytes should be blocked."""
        message = self._request("semantic_search", query="test\x00injection")
        valid, _ = validator.validate_message(message)
        assert valid is False

    def test_valid_relative_path(self, validator):
        """Valid relative paths should be accepted."""
        message = self._request("find_symbol", path="src/main.rs")
        valid, error = validator.validate_message(message)
        assert valid is True
        # Same success contract as test_valid_message: no error on acceptance.
        assert error is None

    def test_valid_workspace_absolute_path(self, validator):
        """Absolute paths within /workspace should be accepted."""
        message = self._request("find_symbol", path="/workspace/src/main.rs")
        valid, error = validator.validate_message(message)
        assert valid is True
        # Same success contract as test_valid_message: no error on acceptance.
        assert error is None

# =============================================================================
# RateLimiter Tests
# =============================================================================

class TestRateLimiter:
    """Checks RateLimiter admission decisions and the stats it reports."""

    @staticmethod
    def _consume(limiter, count):
        """Issue *count* requests against *limiter*, ignoring the verdicts."""
        for _ in range(count):
            limiter.allow_request()

    def test_allows_requests_under_limit(self):
        """Every request up to max_requests is admitted."""
        limiter = RateLimiter(max_requests=10, window_seconds=60)
        assert all(limiter.allow_request() is True for _ in range(10))

    def test_blocks_requests_over_limit(self):
        """Once the budget is spent, the next request is refused."""
        limiter = RateLimiter(max_requests=5, window_seconds=60)
        self._consume(limiter, 5)
        assert limiter.allow_request() is False

    def test_stats(self):
        """get_stats reflects the requests made so far."""
        limiter = RateLimiter(max_requests=10, window_seconds=60)
        self._consume(limiter, 3)

        stats = limiter.get_stats()
        expected = {"current_requests": 3, "max_requests": 10, "remaining": 7}
        for key, value in expected.items():
            assert stats[key] == value

# =============================================================================
# AuditLogger Tests
# =============================================================================

class TestAuditLogger:
    """Checks the JSON-lines events AuditLogger writes to audit.jsonl."""

    @staticmethod
    def _ready_tenant(workspace):
        """Return tenant 'test123' rooted at *workspace* with its index dir created."""
        tenant = TenantContext(tenant_id="test123", workspace_root=Path(workspace))
        tenant.ensure_index_directory()
        return tenant

    @staticmethod
    def _first_event(audit_file):
        """Parse and return the first JSON line of *audit_file*."""
        with open(audit_file) as handle:
            return json.loads(handle.readline())

    def test_log_event(self):
        """log_event records type, tenant and details."""
        with tempfile.TemporaryDirectory() as workspace:
            tenant = self._ready_tenant(workspace)
            audit_log = AuditLogger(tenant)
            audit_log.log_event("test_event", {"key": "value"})

            # The audit file must exist before its contents are inspected.
            audit_file = tenant.index_path / "audit.jsonl"
            assert audit_file.exists()

            event = self._first_event(audit_file)
            assert event["event_type"] == "test_event"
            assert event["tenant_id"] == "test123"
            assert event["details"]["key"] == "value"

    def test_log_request(self):
        """log_request records an mcp_request event with the method name."""
        with tempfile.TemporaryDirectory() as workspace:
            tenant = self._ready_tenant(workspace)
            audit_log = AuditLogger(tenant)
            audit_log.log_request("semantic_search", {"query": "test"}, allowed=True)

            event = self._first_event(tenant.index_path / "audit.jsonl")
            assert event["event_type"] == "mcp_request"
            assert event["details"]["method"] == "semantic_search"

    def test_log_validation_failure(self):
        """log_validation_failure records an unsuccessful validation_failure event."""
        with tempfile.TemporaryDirectory() as workspace:
            tenant = self._ready_tenant(workspace)
            audit_log = AuditLogger(tenant)
            audit_log.log_validation_failure(
                "blocked pattern",
                {"method": "semantic_search"},
            )

            event = self._first_event(tenant.index_path / "audit.jsonl")
            assert event["event_type"] == "validation_failure"
            assert event["success"] is False

# =============================================================================
# Integration Tests
# =============================================================================

class TestSecurityIntegration:
    """Integration tests for security scenarios."""

    def test_multi_tenant_isolation(self):
        """Each tenant should have isolated index paths."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = Path(tmpdir)

            tenant_a = TenantContext(tenant_id="tenant_a", workspace_root=workspace)
            tenant_b = TenantContext(tenant_id="tenant_b", workspace_root=workspace)

            path_a = tenant_a.ensure_index_directory()
            path_b = tenant_b.ensure_index_directory()

            # Paths should be different
            assert path_a != path_b

            # Neither should be able to access the other's path
            assert not str(path_a).startswith(str(path_b))
            assert not str(path_b).startswith(str(path_a))

    def test_injection_attack_vectors(self):
        """Common injection attack vectors should be blocked or handled safely.

        Every vector must get a definite boolean verdict, so validation must
        not crash. Vectors of the same shape as cases TestInputValidator shows
        are rejected (JSON control characters, NUL bytes) must be rejected.
        """
        validator = InputValidator(SecurityConfig())

        attack_vectors = [
            # JSON injection
            '"}}, {"method": "execute"}',
            # Literal backslash text "\x00" (not an actual NUL byte)
            '\\x00malicious',
            # Path traversal
            '../../../etc/passwd',
            '..\\..\\..\\windows\\system32',
            # Command injection (if somehow used)
            '; rm -rf /',
            '| cat /etc/passwd',
            '$(whoami)',
            '`whoami`',
        ]

        for vector in attack_vectors:
            message = {
                "method": "test",
                "params": {"query": vector}
            }
            valid, _ = validator.validate_message(message)
            # Most should be blocked by our patterns; some may pass but won't
            # do damage due to our architecture. Either way the validator
            # must return a real verdict.
            assert isinstance(valid, bool), f"non-bool verdict for {vector!r}"

        # These must be rejected: they match the JSON-injection / curly-brace
        # and NUL-byte cases rejected in TestInputValidator, sent in the same
        # semantic_search request shape used there.
        must_block = [
            '"}}, {"method": "execute"}',
            '\x00malicious',
        ]
        for vector in must_block:
            message = {
                "jsonrpc": "2.0",
                "id": 1,
                "method": "semantic_search",
                "params": {"query": vector},
            }
            valid, _ = validator.validate_message(message)
            assert valid is False, f"attack vector was not blocked: {vector!r}"

if __name__ == "__main__":
    # Propagate pytest's exit status so a direct script run fails when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))