# ADR: background/task_manager.py

## Status
Accepted | YYYY-MM-DD

## Context
The current situation requires a decision because:
- Requirement 1
- Constraint 2
- Need 3
background/task_manager.py
import asyncio
import json
import logging
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4
class TaskStatus(Enum):
    """Lifecycle states of a background processing job.

    PENDING/PROCESSING are in-flight states; COMPLETED/FAILED are terminal;
    RETRYING marks a failed attempt that still has retry budget.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    RETRYING = "retrying"
@dataclass
class TaskMetadata:
    """Bookkeeping snapshot for a single task.

    Mirrors the columns of the processing_jobs table; started_at,
    completed_at and error are None until the corresponding event occurs.
    """

    task_id: UUID
    task_type: str
    created_at: datetime
    started_at: Optional[datetime]
    completed_at: Optional[datetime]
    status: TaskStatus
    retries: int
    max_retries: int
    error: Optional[str]
class TaskManager:
    """Coordinates background job execution backed by a database job table."""

    def __init__(self, db_connection, max_workers: int = 5):
        """Set up the queue and bookkeeping; workers start in initialize().

        Args:
            db_connection: async DB handle exposing transaction().
            max_workers: number of concurrent worker coroutines.
        """
        # Fixed: the original declared `def init` (never called as a
        # constructor) and used `logging.getLogger(name)` (NameError).
        self.db = db_connection
        self.max_workers = max_workers
        self.processing_queue = asyncio.Queue()
        self.active_tasks: Dict[UUID, asyncio.Task] = {}
        # Populated by initialize(); empty list keeps shutdown() safe
        # even if initialize() was never called.
        self.workers: List[asyncio.Task] = []
        self.logger = logging.getLogger(__name__)
async def initialize(self):
"""Initialize task manager and start workers"""
self.logger.info("Initializing task manager")
# Start worker tasks
self.workers = [
asyncio.create_task(self._worker(i))
for i in range(self.max_workers)
]
# Recover any pending tasks from database
await self._recover_pending_tasks()
async def _recover_pending_tasks(self):
"""Recover pending tasks from database"""
async with self.db.transaction() as conn:
pending_tasks = await conn.fetch("""
SELECT * FROM processing_jobs
WHERE status IN ('pending', 'processing')
AND attempts < max_attempts
ORDER BY created_at ASC
""")
for task in pending_tasks:
await self.processing_queue.put({
'task_id': task['job_uuid'],
'task_type': task['job_type'],
'params': task['parameters']
})
async def submit_task(
self,
task_type: str,
parameters: Dict[str, Any],
max_retries: int = 3
) -> UUID:
"""Submit a new task for processing"""
task_id = uuid4()
# Store task in database
async with self.db.transaction() as conn:
await conn.execute("""
INSERT INTO processing_jobs (
job_uuid, job_type, parameters,
status, max_attempts, created_at
) VALUES ($1, $2, $3, $4, $5, CURRENT_TIMESTAMP)
""", task_id, task_type, json.dumps(parameters),
TaskStatus.PENDING.value, max_retries)
# Add task to processing queue
await self.processing_queue.put({
'task_id': task_id,
'task_type': task_type,
'params': parameters
})
return task_id
async def _worker(self, worker_id: int):
"""Worker process that handles tasks"""
self.logger.info(f"Starting worker {worker_id}")
while True:
try:
# Get task from queue
task = await self.processing_queue.get()
task_id = task['task_id']
try:
# Update task status to processing
await self._update_task_status(
task_id,
TaskStatus.PROCESSING
)
# Process task based on type
result = await self._process_task(task)
# Mark task as completed
await self._update_task_status(
task_id,
TaskStatus.COMPLETED,
result=result
)
except Exception as e:
self.logger.error(
f"Error processing task {task_id}: {str(e)}"
)
# Handle task failure
await self._handle_task_failure(task_id, str(e))
finally:
self.processing_queue.task_done()
except Exception as e:
self.logger.error(
f"Worker {worker_id} encountered error: {str(e)}"
)
await asyncio.sleep(1) # Prevent tight error loop
async def _process_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
"""Process a task based on its type"""
task_type = task['task_type']
params = task['params']
if task_type == 'document_processing':
return await self._process_document_task(params)
elif task_type == 'embedding_generation':
return await self._process_embedding_task(params)
elif task_type == 'relationship_creation':
return await self._process_relationship_task(params)
else:
raise ValueError(f"Unknown task type: {task_type}")
async def _process_document_task(
self,
params: Dict[str, Any]
) -> Dict[str, Any]:
"""Process a document chunking task"""
doc_uuid = params['doc_uuid']
content = params['content']
chunk_size = params.get('chunk_size', 1000)
overlap = params.get('overlap', 100)
# Create chunks with overlap
chunks = self._create_chunks(content, chunk_size, overlap)
# Store chunks in database
async with self.db.transaction() as conn:
for i, chunk in enumerate(chunks):
await conn.execute("""
INSERT INTO chunks (
doc_uuid, sequence_num, content,
start_offset, end_offset
) VALUES ($1, $2, $3, $4, $5)
""", doc_uuid, i, chunk['content'],
chunk['start'], chunk['end'])
# Submit embedding generation tasks for chunks
for chunk in chunks:
await self.submit_task(
'embedding_generation',
{
'chunk_uuid': chunk['chunk_uuid'],
'content': chunk['content']
}
)
return {
'doc_uuid': doc_uuid,
'chunks_created': len(chunks)
}
def _create_chunks(
self,
content: str,
chunk_size: int,
overlap: int
) -> List[Dict[str, Any]]:
"""Create chunks from content with overlap"""
chunks = []
start = 0
while start < len(content):
# Calculate end position
end = start + chunk_size
if end < len(content):
# Find next space to avoid breaking words
while end < len(content) and not content[end].isspace():
end += 1
else:
end = len(content)
# Create chunk
chunk_content = content[start:end]
chunks.append({
'chunk_uuid': uuid4(),
'content': chunk_content,
'start': start,
'end': end
})
# Move start position for next chunk
start = end - overlap
# Ensure we don't start in the middle of a word
if start > 0:
while start < len(content) and not content[start].isspace():
start += 1
start = min(start + 1, len(content))
return chunks
async def _process_embedding_task(
self,
params: Dict[str, Any]
) -> Dict[str, Any]:
"""Process embedding generation for a chunk"""
chunk_uuid = params['chunk_uuid']
content = params['content']
# Generate embedding
embedding = await self.embedding_model.embed_text(content)
# Store embedding in database
async with self.db.transaction() as conn:
await conn.execute("""
UPDATE chunks
SET embedding = $1
WHERE chunk_uuid = $2
""", embedding.tolist(), chunk_uuid)
return {
'chunk_uuid': chunk_uuid,
'embedding_size': len(embedding)
}
async def _update_task_status(
self,
task_id: UUID,
status: TaskStatus,
result: Optional[Dict[str, Any]] = None,
error: Optional[str] = None
):
"""Update task status in database"""
async with self.db.transaction() as conn:
await conn.execute("""
UPDATE processing_jobs
SET status = $1,
result = $2,
error_message = $3,
updated_at = CURRENT_TIMESTAMP,
completed_at = CASE
WHEN $1 IN ('completed', 'failed')
THEN CURRENT_TIMESTAMP
ELSE completed_at
END
WHERE job_uuid = $4
""", status.value, json.dumps(result) if result else None,
error, task_id)
async def _handle_task_failure(self, task_id: UUID, error: str):
"""Handle task failure and retry logic"""
async with self.db.transaction() as conn:
# Get current task state
task = await conn.fetchrow("""
SELECT * FROM processing_jobs
WHERE job_uuid = $1
""", task_id)
attempts = task['attempts'] + 1
max_attempts = task['max_attempts']
if attempts < max_attempts:
# Update for retry
await conn.execute("""
UPDATE processing_jobs
SET status = $1,
attempts = $2,
error_message = $3,
updated_at = CURRENT_TIMESTAMP
WHERE job_uuid = $4
""", TaskStatus.RETRYING.value, attempts,
error, task_id)
# Re-queue task
await self.processing_queue.put({
'task_id': task_id,
'task_type': task['job_type'],
'params': json.loads(task['parameters'])
})
else:
# Mark as failed
await conn.execute("""
UPDATE processing_jobs
SET status = $1,
attempts = $2,
error_message = $3,
updated_at = CURRENT_TIMESTAMP,
completed_at = CURRENT_TIMESTAMP
WHERE job_uuid = $4
""", TaskStatus.FAILED.value, attempts,
error, task_id)
async def get_task_status(self, task_id: UUID) -> Dict[str, Any]:
"""Get current task status"""
async with self.db.transaction() as conn:
task = await conn.fetchrow("""
SELECT
job_uuid,
job_type,
status,
attempts,
max_attempts,
error_message,
created_at,
updated_at,
completed_at,
result
FROM processing_jobs
WHERE job_uuid = $1
""", task_id)
return dict(task) if task else None
async def shutdown(self):
"""Shutdown task manager gracefully"""
self.logger.info("Shutting down task manager")
# Cancel all worker tasks
for worker in self.workers:
worker.cancel()
# Wait for workers to complete
await asyncio.gather(*self.workers, return_exceptions=True)
# Clear queue
while not self.processing_queue.empty():
await self.processing_queue.get()
self.processing_queue.task_done()