#!/usr/bin/env python3
"""
ADR-151 Context Graph Algorithms (Phase 5: CP-36, CP-37)

Implements seed node selection and BFS expansion algorithms for building
task-specific context graphs from the knowledge graph.

Algorithms:
- select_seed_nodes(): Choose starting nodes based on strategy
- bfs_expand(): Breadth-first expansion with relevance scoring
- compute_relevance_score(): Calculate node relevance to task

Created: 2026-02-03
Author: Claude (Opus 4.5)
Track: J (Memory Intelligence)
Task: J.25.2.3, J.25.2.4 (CP-36, CP-37)
"""
import hashlib
import json
import logging
import re
import sqlite3
import time
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
# Module-level logger; uses __name__ (the original `name` is undefined here).
logger = logging.getLogger(__name__)
# =============================================================================
# Data Classes
# =============================================================================
@dataclass
class GraphNode:
    """A single node included in a task-specific context graph."""

    id: str
    node_type: str
    name: str
    subtype: Optional[str] = None
    properties: Optional[Dict[str, Any]] = None
    relevance_score: float = 1.0   # 0-1 relevance to the task; seeds start at 1.0
    depth: int = 0                 # hops from the nearest seed node
    is_seed: bool = False          # True when this node was a BFS starting point
    token_estimate: int = 0        # rough serialized size, see _estimate_tokens()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this node to a plain dictionary."""
        return dict(
            id=self.id,
            node_type=self.node_type,
            name=self.name,
            subtype=self.subtype,
            properties=self.properties,
            relevance_score=self.relevance_score,
            depth=self.depth,
            is_seed=self.is_seed,
            token_estimate=self.token_estimate,
        )
@dataclass
class GraphEdge:
    """A directed edge between two nodes of a context graph."""

    from_node: str
    to_node: str
    edge_type: str
    properties: Optional[Dict[str, Any]] = None
    weight: float = 1.0  # edge strength; multiplies relevance during expansion

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this edge to a plain dictionary."""
        return dict(
            from_node=self.from_node,
            to_node=self.to_node,
            edge_type=self.edge_type,
            properties=self.properties,
            weight=self.weight,
        )
@dataclass
class ContextGraph:
    """A task-specific subgraph assembled from the knowledge graph."""

    id: str
    task_description: str
    nodes: Dict[str, GraphNode] = field(default_factory=dict)
    edges: List[GraphEdge] = field(default_factory=list)
    seed_nodes: List[str] = field(default_factory=list)
    seed_strategy: str = "anchor"
    token_budget: int = 4000
    max_depth: int = 3
    max_nodes: int = 128
    policies_applied: List[Dict[str, Any]] = field(default_factory=list)
    phi_node_count: int = 0  # J.25.4.3: Count of nodes flagged as potential PHI

    @property
    def node_count(self) -> int:
        """Number of nodes currently in the graph."""
        return len(self.nodes)

    @property
    def edge_count(self) -> int:
        """Number of edges currently in the graph."""
        return len(self.edges)

    @property
    def tokens_estimated(self) -> int:
        """Sum of per-node token estimates across the whole graph."""
        return sum(node.token_estimate for node in self.nodes.values())

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the graph (nodes, edges, and metadata) to a dictionary."""
        return dict(
            id=self.id,
            task_description=self.task_description,
            nodes={node_id: node.to_dict() for node_id, node in self.nodes.items()},
            edges=[edge.to_dict() for edge in self.edges],
            seed_nodes=self.seed_nodes,
            seed_strategy=self.seed_strategy,
            token_budget=self.token_budget,
            max_depth=self.max_depth,
            max_nodes=self.max_nodes,
            node_count=self.node_count,
            edge_count=self.edge_count,
            tokens_estimated=self.tokens_estimated,
            policies_applied=self.policies_applied,
            phi_node_count=self.phi_node_count,
        )
# =============================================================================
# CP-36: Seed Node Selection Algorithm
# =============================================================================
def select_seed_nodes(
    conn: sqlite3.Connection,
    task_description: str,
    strategy: str = "anchor",
    anchor_node_ids: Optional[List[str]] = None,
    node_types: Optional[List[str]] = None,
    max_seeds: int = 5,
    relevance_fn: Optional[Callable[[str, str], float]] = None,
) -> List[Tuple[str, float]]:
    """
    CP-36: Select seed nodes for context graph expansion.

    Strategies:
    - anchor: Use provided anchor_node_ids directly
    - semantic: FTS5 search on task_description, rank by relevance
    - policy_first: Start from policy nodes, then expand

    Args:
        conn: Connection to org.db containing kg_nodes
        task_description: Natural language task description
        strategy: One of 'anchor', 'semantic', 'policy_first'
        anchor_node_ids: Explicit node IDs to use as seeds (for 'anchor')
        node_types: Filter seed nodes by type (optional)
        max_seeds: Maximum number of seed nodes to return
        relevance_fn: Optional custom relevance scoring function
            (accepted for API compatibility; not consulted by any
            strategy in this implementation)

    Returns:
        List of (node_id, relevance_score) tuples, sorted by relevance
    """
    logger.info(f"Selecting seed nodes with strategy: {strategy}")

    # Dispatch to the matching strategy via guard returns.
    if strategy == "anchor":
        return _select_anchor_seeds(conn, anchor_node_ids, max_seeds)
    if strategy == "semantic":
        return _select_semantic_seeds(conn, task_description, node_types, max_seeds)
    if strategy == "policy_first":
        return _select_policy_seeds(conn, task_description, max_seeds)

    # Unrecognized strategy: warn and behave like 'semantic'.
    logger.warning(f"Unknown strategy '{strategy}', falling back to semantic")
    return _select_semantic_seeds(conn, task_description, node_types, max_seeds)
def _select_anchor_seeds(
    conn: sqlite3.Connection,
    anchor_node_ids: Optional[List[str]],
    max_seeds: int,
) -> List[Tuple[str, float]]:
    """
    Select seeds from explicitly provided anchor node IDs.

    Each anchor that exists in kg_nodes is returned with full
    relevance (1.0); missing anchors are logged and skipped.
    """
    if not anchor_node_ids:
        logger.warning("No anchor_node_ids provided, returning empty seeds")
        return []

    confirmed: List[Tuple[str, float]] = []
    for candidate in anchor_node_ids[:max_seeds]:
        row = conn.execute(
            "SELECT id FROM kg_nodes WHERE id = ?",
            (candidate,)
        ).fetchone()
        if row is None:
            logger.warning(f"Anchor node not found: {candidate}")
        else:
            confirmed.append((candidate, 1.0))

    logger.info(f"Selected {len(confirmed)} anchor seeds")
    return confirmed
def _select_semantic_seeds(
    conn: sqlite3.Connection,
    task_description: str,
    node_types: Optional[List[str]],
    max_seeds: int,
) -> List[Tuple[str, float]]:
    """
    Select seeds using FTS5 semantic search on the task description.

    Tokenizes the description into an OR query against the
    kg_nodes_fts index and normalizes BM25 ranks into 0-1 relevance
    scores. Falls back to a LIKE-based scan when FTS5 is unavailable
    or returns no rows.

    Args:
        conn: Connection to org.db (assumes kg_nodes_fts shares rowids
            with kg_nodes — verify against the schema migration)
        task_description: Natural language task description
        node_types: Optional node_type filter
        max_seeds: Maximum number of seeds to return

    Returns:
        List of (node_id, relevance_score) tuples, best match first
    """
    # Tokenize: lowercase words, drop stop words and very short tokens.
    tokens = re.findall(r'\b\w+\b', task_description.lower())
    stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'to', 'of',
                  'and', 'or', 'in', 'on', 'for', 'with', 'as', 'at', 'by', 'from'}
    tokens = [t for t in tokens if t not in stop_words and len(t) > 2]

    if not tokens:
        logger.warning("No valid search tokens found in task description")
        return []

    # Build FTS5 query
    fts_query = " OR ".join(tokens[:10])  # Limit tokens for performance

    try:
        # Query FTS5 index with BM25 ranking
        query = """
            SELECT n.id, n.node_type, n.name,
                   bm25(kg_nodes_fts) AS rank
            FROM kg_nodes_fts f
            JOIN kg_nodes n ON f.rowid = n.rowid
            WHERE kg_nodes_fts MATCH ?
        """
        params: List[Any] = [fts_query]

        # Add node type filter if specified
        if node_types:
            placeholders = ",".join("?" * len(node_types))
            query += f" AND n.node_type IN ({placeholders})"
            params.extend(node_types)

        query += " ORDER BY rank LIMIT ?"
        params.append(max_seeds)

        rows = conn.execute(query, params).fetchall()
        if not rows:
            logger.info("FTS5 search returned no results, falling back to LIKE search")
            return _select_fallback_seeds(conn, tokens, node_types, max_seeds)

        # BM25 returns negative values; more negative = more relevant.
        # Map the best (most negative) rank to 1.0 and the worst to 0.0.
        # When all ranks tie, range_rank is forced to 1.0 so every row
        # scores 1.0 (the old `if range_rank != 0` branch was dead code:
        # range_rank could never be 0 after this coercion).
        min_rank = min(r[3] for r in rows)
        max_rank = max(r[3] for r in rows)
        range_rank = max_rank - min_rank if max_rank != min_rank else 1.0

        seeds = [
            (row[0], round(1.0 - (row[3] - min_rank) / range_rank, 4))
            for row in rows
        ]

        logger.info(f"Selected {len(seeds)} semantic seeds via FTS5")
        return seeds

    except sqlite3.OperationalError as e:
        # Typically: FTS5 not compiled in, or kg_nodes_fts missing.
        logger.warning(f"FTS5 query failed: {e}, falling back to LIKE search")
        return _select_fallback_seeds(conn, tokens, node_types, max_seeds)
def _select_fallback_seeds(
    conn: sqlite3.Connection,
    tokens: List[str],
    node_types: Optional[List[str]],
    max_seeds: int,
) -> List[Tuple[str, float]]:
    """
    Fallback seed selection using LIKE queries when FTS5 fails.

    Matches each of the first five tokens against node names with
    LIKE; earlier tokens yield higher scores (0.9, 0.8, ..., floored
    at 0.5). Each node ID is collected at most once.

    Args:
        conn: Connection to org.db containing kg_nodes
        tokens: Pre-filtered search tokens, highest-priority first
        node_types: Optional node_type filter
        max_seeds: Maximum number of seeds to return

    Returns:
        List of (node_id, relevance_score) tuples
    """
    seeds: List[Tuple[str, float]] = []
    seen_ids: Set[str] = set()

    for position, token in enumerate(tokens[:5]):  # Limit tokens for performance
        # Fix: the original kept looping tokens after the quota was
        # filled, issuing useless `LIMIT 0` queries.
        if len(seeds) >= max_seeds:
            break

        query = "SELECT id, name FROM kg_nodes WHERE name LIKE ?"
        params: List[Any] = [f"%{token}%"]

        if node_types:
            placeholders = ",".join("?" * len(node_types))
            query += f" AND node_type IN ({placeholders})"
            params.extend(node_types)

        query += " LIMIT ?"
        params.append(max_seeds - len(seeds))

        for row in conn.execute(query, params):
            if row[0] in seen_ids:
                continue
            seen_ids.add(row[0])
            # Score by token position. enumerate() gives the correct
            # position even for duplicate tokens, where the original
            # tokens.index(token) always returned the first occurrence.
            seeds.append((row[0], max(0.5, 0.9 - position * 0.1)))
            if len(seeds) >= max_seeds:
                break

    logger.info(f"Selected {len(seeds)} seeds via LIKE fallback")
    return seeds[:max_seeds]
def _select_policy_seeds(
    conn: sqlite3.Connection,
    task_description: str,
    max_seeds: int,
) -> List[Tuple[str, float]]:
    """
    Select seeds starting from policy nodes (for governance-focused graphs).

    Takes the most recently updated 'policy' nodes; when none exist,
    falls back to the most recently updated 'adr' nodes. Every seed is
    returned with full relevance (1.0). Note: task_description is not
    consulted by this strategy.
    """
    rows = conn.execute("""
        SELECT id, name, properties
        FROM kg_nodes
        WHERE node_type = 'policy'
        ORDER BY updated_at DESC
        LIMIT ?
    """, (max_seeds,)).fetchall()

    if not rows:
        logger.info("No policy nodes found, falling back to ADR nodes")
        rows = conn.execute("""
            SELECT id, name, properties
            FROM kg_nodes
            WHERE node_type = 'adr'
            ORDER BY updated_at DESC
            LIMIT ?
        """, (max_seeds,)).fetchall()

    chosen = [(row[0], 1.0) for row in rows]
    logger.info(f"Selected {len(chosen)} policy/ADR seeds")
    return chosen
# =============================================================================
# CP-37: BFS Expansion with Relevance Scoring
# =============================================================================
def bfs_expand(
    conn: sqlite3.Connection,
    seed_nodes: List[Tuple[str, float]],
    max_depth: int = 3,
    max_nodes: int = 128,
    relevance_threshold: float = 0.3,
    edge_types: Optional[List[str]] = None,
    decay_factor: float = 0.8,
) -> ContextGraph:
    """
    CP-37: BFS expansion from seed nodes with relevance scoring.

    Expands outward from seed nodes following edges in both directions,
    multiplying relevance by decay_factor at each hop and by any
    per-edge "weight" property. Neighbors below relevance_threshold
    are pruned; expansion stops at max_depth hops or max_nodes nodes.

    Args:
        conn: Connection to org.db containing kg_nodes and kg_edges
        seed_nodes: List of (node_id, relevance_score) from select_seed_nodes
        max_depth: Maximum BFS depth (hops from seed)
        max_nodes: Maximum nodes to include
        relevance_threshold: Minimum relevance to include a node
        edge_types: Filter edges by type (None = all edges)
        decay_factor: Relevance multiplier per hop (0.8 = 20% decay)

    Returns:
        ContextGraph with expanded nodes and edges
    """
    logger.info(f"Starting BFS expansion from {len(seed_nodes)} seeds")
    logger.info(f"Parameters: max_depth={max_depth}, max_nodes={max_nodes}, threshold={relevance_threshold}")

    # Graph ID: timestamp plus a short digest of the seed list (not
    # security-sensitive; just a cheap unique-ish identifier).
    graph_id = f"cg:{int(time.time())}:{hashlib.md5(str(seed_nodes).encode()).hexdigest()[:8]}"
    graph = ContextGraph(
        id=graph_id,
        task_description="",
        seed_nodes=[s[0] for s in seed_nodes],
        max_depth=max_depth,
        max_nodes=max_nodes,
    )

    # BFS queue of (node_id, depth, relevance_score)
    queue: deque = deque()
    visited: Set[str] = set()

    # Seed the graph and the queue.
    for node_id, relevance in seed_nodes:
        node = _fetch_node(conn, node_id)
        if node:
            node.relevance_score = relevance
            node.depth = 0
            node.is_seed = True
            node.token_estimate = _estimate_tokens(node)
            graph.nodes[node_id] = node
            visited.add(node_id)
            queue.append((node_id, 0, relevance))

    # BFS expansion
    while queue and len(graph.nodes) < max_nodes:
        current_id, current_depth, current_relevance = queue.popleft()

        # Nodes at max_depth are kept but not expanded further.
        if current_depth >= max_depth:
            continue

        neighbors = _get_neighbors(conn, current_id, edge_types)

        for neighbor_id, edge_type, edge_properties in neighbors:
            if neighbor_id in visited:
                continue

            # Parse edge properties ONCE and reuse the parsed dict for
            # both the weight multiplier and the stored GraphEdge.
            # Fixes the original code, which parsed JSON-string
            # properties for the weight but then stored properties=None
            # on the edge (SQLite returns properties as JSON text).
            props: Optional[Dict[str, Any]] = None
            if isinstance(edge_properties, dict):
                props = edge_properties
            elif isinstance(edge_properties, str):
                try:
                    parsed = json.loads(edge_properties)
                    props = parsed if isinstance(parsed, dict) else None
                except (json.JSONDecodeError, TypeError):
                    props = None

            # Relevance decays per hop, scaled by the edge weight if any.
            neighbor_relevance = current_relevance * decay_factor
            if props and "weight" in props:
                neighbor_relevance *= props["weight"]

            # Skip if below threshold
            if neighbor_relevance < relevance_threshold:
                continue

            neighbor_node = _fetch_node(conn, neighbor_id)
            if neighbor_node:
                neighbor_node.relevance_score = round(neighbor_relevance, 4)
                neighbor_node.depth = current_depth + 1
                neighbor_node.token_estimate = _estimate_tokens(neighbor_node)
                graph.nodes[neighbor_id] = neighbor_node
                visited.add(neighbor_id)

                graph.edges.append(GraphEdge(
                    from_node=current_id,
                    to_node=neighbor_id,
                    edge_type=edge_type,
                    properties=props,
                ))

                # Enqueue for further expansion.
                queue.append((neighbor_id, current_depth + 1, neighbor_relevance))

                if len(graph.nodes) >= max_nodes:
                    logger.info(f"Reached max_nodes limit: {max_nodes}")
                    break

    logger.info(f"BFS expansion complete: {graph.node_count} nodes, {graph.edge_count} edges")
    return graph
def _fetch_node(conn: sqlite3.Connection, node_id: str) -> Optional[GraphNode]:
    """Fetch a row from kg_nodes by ID and wrap it as a GraphNode."""
    row = conn.execute("""
        SELECT id, node_type, name, subtype, properties
        FROM kg_nodes
        WHERE id = ?
    """, (node_id,)).fetchone()

    if row is None:
        return None

    # Properties are stored as JSON text; malformed JSON is treated as absent.
    raw_props = row[4]
    props: Optional[Dict[str, Any]] = None
    if raw_props:
        try:
            props = json.loads(raw_props)
        except json.JSONDecodeError:
            props = None

    return GraphNode(
        id=row[0],
        node_type=row[1],
        name=row[2],
        subtype=row[3],
        properties=props,
    )
def _get_neighbors( conn: sqlite3.Connection, node_id: str, edge_types: Optional[List[str]] = None, ) -> List[Tuple[str, str, Optional[Dict]]]: """Get all neighbors of a node via edges.""" neighbors = []
# Outgoing edges
query = "SELECT to_node, edge_type, properties FROM kg_edges WHERE from_node = ?"
params = [node_id]
if edge_types:
placeholders = ",".join("?" * len(edge_types))
query += f" AND edge_type IN ({placeholders})"
params.extend(edge_types)
cursor = conn.execute(query, params)
for row in cursor:
neighbors.append((row[0], row[1], row[2]))
# Incoming edges (for bidirectional traversal)
query = "SELECT from_node, edge_type, properties FROM kg_edges WHERE to_node = ?"
params = [node_id]
if edge_types:
placeholders = ",".join("?" * len(edge_types))
query += f" AND edge_type IN ({placeholders})"
params.extend(edge_types)
cursor = conn.execute(query, params)
for row in cursor:
neighbors.append((row[0], row[1], row[2]))
return neighbors
def _estimate_tokens(node: GraphNode) -> int: """ Estimate token count for a node when serialized.
Rough estimation: ~4 chars per token
"""
text_len = len(node.name)
if node.properties:
text_len += len(json.dumps(node.properties))
return max(10, text_len // 4)
# =============================================================================
# Relevance Scoring
# =============================================================================
def compute_relevance_score(
    node: GraphNode,
    task_description: str,
    context_nodes: Optional[List[GraphNode]] = None,
) -> float:
    """
    Compute relevance score for a node relative to a task.

    Combines three additive factors, capped at 1.0:
    - keyword overlap between the node name and the task (up to 0.4)
    - a per-node-type base weight (0.1 for unknown types)
    - property richness (up to 0.15)

    Note: context_nodes is accepted for API compatibility but is not
    consulted by the current formula.

    Args:
        node: The node to score
        task_description: Task/query description
        context_nodes: Other nodes in context for relative scoring

    Returns:
        Relevance score between 0.0 and 1.0
    """
    total = 0.0

    # Keyword overlap between the node name and the task description.
    task_words = set(re.findall(r'\b\w+\b', task_description.lower()))
    name_words = set(re.findall(r'\b\w+\b', node.name.lower()))
    if task_words and name_words:
        shared = task_words & name_words
        total += min(0.4, len(shared) * 0.1)

    # Base weight by node type; unknown types default to 0.1.
    type_weights = {
        "decision": 0.25,
        "adr": 0.25,
        "policy": 0.2,
        "error_solution": 0.2,
        "component": 0.15,
        "skill_learning": 0.15,
        "function": 0.1,
        "file": 0.1,
        "session": 0.05,
        "track": 0.1,
    }
    total += type_weights.get(node.node_type, 0.1)

    # Richer property sets earn a small bonus.
    if node.properties:
        total += min(0.15, len(node.properties) * 0.03)

    # Clamp into [0, 1] after rounding.
    return min(1.0, round(total, 4))