# scripts/recursive-agent-chain

#!/usr/bin/env python3
"""CODITECT Recursive Agent Chain System (ADR-077).

Implements RLM's recursive self-calling pattern enabling agents to spawn
sub-agents via agent_query() for large context decomposition and nested
reasoning.

Based on analysis of submodules/rlm (alexzhang13/rlm).
"""

from dataclasses import dataclass, field
from typing import List, Optional, Callable, Any, Dict, TYPE_CHECKING
from enum import Enum
import asyncio
import time

# Project types imported only for type checking to avoid a runtime
# dependency (and a potential import cycle) on the usage-tracking package.
if TYPE_CHECKING:
    from scripts.core.usage_tracking import UsageTracker, UsageSummary

class AgentQueryType(Enum):
    """Type of agent query.

    SINGLE is one sub-agent call; BATCH is a set of parallel calls.
    """
    SINGLE = "single"
    BATCH = "batch"

@dataclass
class AgentCallContext:
    """Context for a single agent call in the chain.

    Attributes:
        depth: Current recursion depth (0 = root)
        parent_task_id: PILOT plan task ID of parent
        child_index: Index of this child in batch (if applicable)
        context_snapshot: First 1000 chars of context for debugging
    """
    depth: int
    parent_task_id: str
    child_index: Optional[int] = None
    context_snapshot: Optional[str] = None

@dataclass
class AgentCallResult:
    """Result from an agent call.

    Attributes:
        content: Response content from agent
        usage: Token usage summary (UsageSummary; typed Any to avoid a
            runtime import of the tracking package)
        depth: Recursion depth at which this executed
        duration_ms: Execution time in milliseconds
        child_calls: Number of sub-agent calls made
    """
    content: str
    usage: Optional[Any]  # UsageSummary
    depth: int
    duration_ms: float
    child_calls: int = 0

@dataclass
class ChunkingConfig:
    """Configuration for context chunking.

    Attributes:
        max_tokens: Maximum tokens per chunk (default: 50000)
        overlap_tokens: Overlap between chunks (default: 500)
        strategy: Chunking strategy - "token", "structure", or "semantic"
    """
    max_tokens: int = 50000
    overlap_tokens: int = 500
    strategy: str = "token"  # token, structure, semantic

class RecursionDepthError(Exception):
    """Raised when recursion depth limit is exceeded."""
    pass

class RecursiveAgentChain:
    """Recursive agent chain enabling nested agent calls.

    Based on RLM's llm_query() pattern, adapted for CODITECT agents.
    Enables agents to spawn sub-agents for large context decomposition.

    Attributes:
        MAX_DEPTH: Maximum recursion depth (default: 5)
        MAX_BATCH_SIZE: Maximum parallel queries (default: 10)

    Usage:
        chain = RecursiveAgentChain(model_client)
        result = await chain.execute(
            task="Analyze codebase",
            context=large_codebase,
            task_id="A.1.1"
        )

    The executed agent can call:
        - agent_query(prompt, agent_type) - Single sub-agent call
        - agent_query_batch(prompts, agent_type) - Parallel sub-agent calls
        - chunk_context(context, config) - Split large context
        - FINAL(answer) - Mark completion
    """

    MAX_DEPTH: int = 5
    MAX_BATCH_SIZE: int = 10

    def __init__(
        self,
        model_client: Any,
        usage_tracker: Optional["UsageTracker"] = None
    ):
        """Initialize recursive agent chain.

        Args:
            model_client: Model client for LLM calls (must have an async
                complete(prompt, agent_type) returning a dict-like response)
            usage_tracker: Optional UsageTracker for token tracking
        """
        self.model_client = model_client
        # Stack of active call contexts; its length gives the nesting depth
        # for sequential (non-batched) calls — see execute() for the
        # concurrent case.
        self.call_stack: List["AgentCallContext"] = []
        # Chain-wide count of agent_query() calls; also used to mint unique
        # child task IDs.
        self.total_child_calls: int = 0

        if usage_tracker is not None:
            self.usage_tracker = usage_tracker
        else:
            # Lazy import: only needed when no tracker is injected.
            from scripts.core.usage_tracking import UsageTracker
            self.usage_tracker = UsageTracker()

    @property
    def current_depth(self) -> int:
        """Current recursion depth (0 = not yet started)."""
        return len(self.call_stack)

    async def execute(
        self,
        task: str,
        context: str,
        task_id: str,
        agent_type: str = "general",
        *,
        _depth: Optional[int] = None
    ) -> "AgentCallResult":
        """Execute agent with recursive capabilities.

        The agent can call:
            - agent_query(prompt, agent_type) - Single sub-agent call
            - agent_query_batch(prompts, agent_type) - Parallel sub-agent calls
            - chunk_context(context, config) - Split large context

        Args:
            task: Task description for the agent
            context: Context to analyze (can be large)
            task_id: PILOT plan task ID for tracking
            agent_type: Type of agent to use
            _depth: Internal. Explicit recursion depth for nested calls;
                external callers leave this as None.

        Returns:
            AgentCallResult with content, usage, and metadata

        Raises:
            RecursionDepthError: If max recursion depth exceeded
        """
        # BUGFIX: depth used to be inferred from len(call_stack), which
        # over-counts when siblings run concurrently via agent_query_batch
        # (parallel calls are breadth, not nesting) and could raise
        # RecursionDepthError spuriously. Nested calls now pass their true
        # depth explicitly via _depth.
        depth = self.current_depth if _depth is None else _depth
        if depth >= self.MAX_DEPTH:
            raise RecursionDepthError(f"Max depth {self.MAX_DEPTH} exceeded")

        # Push call context (snapshot only the first 1000 chars for debugging).
        ctx = AgentCallContext(
            depth=depth,
            parent_task_id=task_id,
            context_snapshot=context[:1000] if context else None
        )
        self.call_stack.append(ctx)

        # BUGFIX: snapshot the counter so child_calls reports sub-agent calls
        # made by THIS subtree, matching the AgentCallResult docstring, rather
        # than the chain-wide running total.
        calls_before = self.total_child_calls

        try:
            # Create execution environment with injected functions.
            env = self._create_execution_environment(task_id, depth)

            # Execute agent.
            result = await self._execute_agent(task, context, agent_type, env)

            return AgentCallResult(
                content=result["content"],
                usage=result["usage"],
                depth=ctx.depth,
                duration_ms=result["duration_ms"],
                child_calls=self.total_child_calls - calls_before
            )
        finally:
            # Under concurrent siblings this may pop a sibling's context, but
            # only the stack *length* is consulted, so bookkeeping balances.
            self.call_stack.pop()

    def _create_execution_environment(
        self,
        parent_task_id: str,
        parent_depth: int = 0
    ) -> Dict[str, Any]:
        """Create environment with agent_query functions.

        Args:
            parent_task_id: Task ID for generating child task IDs
            parent_depth: Depth of the calling agent; children execute at
                parent_depth + 1 regardless of concurrency

        Returns:
            Dictionary with injected functions for agent use
        """

        async def agent_query(prompt: str, agent_type: str = "general") -> str:
            """Call a sub-agent with a prompt.

            This function is injected into the agent's execution environment,
            enabling recursive agent chains.

            Args:
                prompt: Task prompt for sub-agent
                agent_type: Type of sub-agent to invoke

            Returns:
                Response content from sub-agent
            """
            self.total_child_calls += 1
            child_task_id = f"{parent_task_id}.sub{self.total_child_calls}"

            result = await self.execute(
                task=prompt,
                context="",  # Sub-agents get no parent context (isolation)
                task_id=child_task_id,
                agent_type=agent_type,
                _depth=parent_depth + 1
            )
            return result.content

        async def agent_query_batch(
            prompts: List[str],
            agent_type: str = "general"
        ) -> List[str]:
            """Call multiple sub-agents in parallel.

            Useful for processing chunks concurrently.

            Args:
                prompts: List of prompts for parallel processing
                agent_type: Type of sub-agents to invoke

            Returns:
                List of response contents

            Raises:
                ValueError: If batch size exceeds MAX_BATCH_SIZE
            """
            if len(prompts) > self.MAX_BATCH_SIZE:
                raise ValueError(
                    f"Batch size {len(prompts)} exceeds max {self.MAX_BATCH_SIZE}"
                )

            tasks = [agent_query(p, agent_type) for p in prompts]
            return await asyncio.gather(*tasks)

        def chunk_context(
            context: str,
            config: Optional["ChunkingConfig"] = None
        ) -> List[str]:
            """Split large context into manageable chunks.

            Strategies:
                - token: Split by token count with overlap
                - structure: Split by code structure (functions, classes)
                - semantic: Split by semantic boundaries

            Args:
                context: Large context to chunk
                config: Chunking configuration (defaults to token strategy)

            Returns:
                List of context chunks
            """
            cfg = config or ChunkingConfig()
            return self._chunk_by_strategy(context, cfg)

        def final_marker(answer: Any) -> Dict[str, Any]:
            """Mark task completion with final answer.

            Args:
                answer: Final answer to return

            Returns:
                Completion marker dict
            """
            return {"__final__": True, "content": answer}

        return {
            "agent_query": agent_query,
            "agent_query_batch": agent_query_batch,
            "chunk_context": chunk_context,
            "FINAL": final_marker
        }

    def _chunk_by_strategy(
        self,
        context: str,
        config: "ChunkingConfig"
    ) -> List[str]:
        """Chunk context based on configured strategy.

        Args:
            context: Context to chunk
            config: Chunking configuration

        Returns:
            List of chunks

        Raises:
            ValueError: If strategy is unknown
        """
        if config.strategy == "token":
            return self._chunk_by_tokens(
                context,
                config.max_tokens,
                config.overlap_tokens
            )
        elif config.strategy == "structure":
            return self._chunk_by_structure(context)
        elif config.strategy == "semantic":
            return self._chunk_by_semantic(context, config.max_tokens)
        else:
            raise ValueError(f"Unknown chunking strategy: {config.strategy}")

    def _chunk_by_tokens(
        self,
        context: str,
        max_tokens: int,
        overlap: int
    ) -> List[str]:
        """Split by approximate token count with overlap.

        Uses rough estimate of 4 characters per token.

        Args:
            context: Context to chunk
            max_tokens: Maximum tokens per chunk
            overlap: Number of overlap tokens between chunks

        Returns:
            List of chunks with overlap
        """
        if not context:
            return []

        # Rough estimate: 4 chars per token.
        chars_per_token = 4
        max_chars = max_tokens * chars_per_token
        overlap_chars = overlap * chars_per_token

        chunks = []
        start = 0

        while start < len(context):
            end = start + max_chars
            chunks.append(context[start:end])

            if end >= len(context):
                break
            # BUGFIX: if overlap >= max_tokens the window never advanced and
            # this looped forever; fall back to non-overlapping steps.
            next_start = end - overlap_chars
            start = next_start if next_start > start else end

        return chunks

    def _chunk_by_structure(self, context: str) -> List[str]:
        """Split by code structure (functions, classes).

        Currently uses simple double-newline splitting.
        Future: Use AST parsing for more accurate splitting.

        Args:
            context: Code context to chunk

        Returns:
            List of structural chunks
        """
        # Simple implementation: split on double newlines, which roughly
        # corresponds to function/class boundaries; drop blank sections.
        chunks = [c for c in context.split("\n\n") if c.strip()]
        return chunks if chunks else [context] if context else []

    def _chunk_by_semantic(
        self,
        context: str,
        max_tokens: int
    ) -> List[str]:
        """Split by semantic boundaries.

        Currently falls back to token chunking.
        Future: Use embeddings for semantic splitting.

        Args:
            context: Context to chunk
            max_tokens: Maximum tokens per chunk

        Returns:
            List of semantic chunks
        """
        # Placeholder: fall back to token chunking with a 500-token overlap.
        # Future: use embeddings to find semantic boundaries.
        return self._chunk_by_tokens(context, max_tokens, 500)

    async def _execute_agent(
        self,
        task: str,
        context: str,
        agent_type: str,
        env: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Execute agent with environment.

        Args:
            task: Task description
            context: Context for agent
            agent_type: Type of agent
            env: Execution environment with injected functions

        Returns:
            Dict with content, usage, and duration_ms
        """
        # NOTE(review): env is not consumed here yet — presumably a
        # code-execution step downstream injects these functions; confirm.

        # perf_counter is monotonic, so durations are immune to wall-clock
        # adjustments (time.time() was used previously).
        start = time.perf_counter()

        # Build prompt with available functions documentation.
        prompt = f"""You are a CODITECT agent with recursive capabilities.

Task

{task}

Context

The context is available as a variable. You can:

1. Examine it: len(context), context[:1000]
2. Chunk it: chunks = chunk_context(context, ChunkingConfig(max_tokens=50000))
3. Query sub-agents: result = await agent_query(prompt, "analyst")
4. Batch queries: results = await agent_query_batch(prompts, "analyst")
5. Complete: FINAL(your_answer)

Context Length

{len(context)} characters

Available Context Preview

{context[:2000] if context else "No context provided"}

Generate code to accomplish the task. Use FINAL(answer) when done."""

        # Call model. Response is assumed dict-like with optional
        # content/usage/model/provider keys — TODO confirm client contract.
        response = await self.model_client.complete(prompt, agent_type)

        # Track usage if the provider reported token counts.
        if response.get("usage"):
            self.usage_tracker.record_usage(
                model=response.get("model", "unknown"),
                provider=response.get("provider", "unknown"),
                input_tokens=response["usage"].get("input_tokens", 0),
                output_tokens=response["usage"].get("output_tokens", 0)
            )

        duration = (time.perf_counter() - start) * 1000

        return {
            "content": response.get("content", ""),
            "usage": self.usage_tracker.get_summary(),
            "duration_ms": duration
        }

# --- Convenience functions ---

def create_chain(model_client: Any) -> "RecursiveAgentChain":
    """Create a RecursiveAgentChain with default settings.

    Args:
        model_client: Model client for LLM calls

    Returns:
        Configured RecursiveAgentChain
    """
    return RecursiveAgentChain(model_client)

async def execute_recursive(
    model_client: Any,
    task: str,
    context: str,
    task_id: str,
    agent_type: str = "general"
) -> "AgentCallResult":
    """Execute a recursive agent task with a one-shot chain.

    Args:
        model_client: Model client for LLM calls
        task: Task description
        context: Context to analyze
        task_id: PILOT plan task ID
        agent_type: Type of agent

    Returns:
        AgentCallResult with content and usage
    """
    chain = RecursiveAgentChain(model_client)
    return await chain.execute(task, context, task_id, agent_type)