Source code for arshai.extensions.hooks

"""
Hook system for extending Arshai behavior.
"""

from typing import Callable, Dict, List, Any, Optional
from enum import Enum
import asyncio
from dataclasses import dataclass


[docs] class HookType(Enum): """Types of hooks available in the system.""" # Agent hooks BEFORE_AGENT_PROCESS = "before_agent_process" AFTER_AGENT_PROCESS = "after_agent_process" # Workflow hooks BEFORE_WORKFLOW_START = "before_workflow_start" AFTER_WORKFLOW_END = "after_workflow_end" BEFORE_NODE_EXECUTE = "before_node_execute" AFTER_NODE_EXECUTE = "after_node_execute" # Memory hooks BEFORE_MEMORY_SAVE = "before_memory_save" AFTER_MEMORY_SAVE = "after_memory_save" BEFORE_MEMORY_RETRIEVE = "before_memory_retrieve" AFTER_MEMORY_RETRIEVE = "after_memory_retrieve" # Tool hooks BEFORE_TOOL_EXECUTE = "before_tool_execute" AFTER_TOOL_EXECUTE = "after_tool_execute" # LLM hooks BEFORE_LLM_CALL = "before_llm_call" AFTER_LLM_CALL = "after_llm_call"
[docs] @dataclass class HookContext: """Context passed to hook functions.""" hook_type: HookType data: Dict[str, Any] metadata: Dict[str, Any] = None def __post_init__(self): if self.metadata is None: self.metadata = {}
[docs] class Hook: """ Represents a hook that can be registered in the system. """
[docs] def __init__( self, name: str, hook_type: HookType, callback: Callable, priority: int = 0, enabled: bool = True ): """ Initialize a hook. Args: name: Unique name for the hook hook_type: Type of hook (when it should be called) callback: Function to call when hook is triggered priority: Priority for execution order (higher = earlier) enabled: Whether the hook is enabled """ self.name = name self.hook_type = hook_type self.callback = callback self.priority = priority self.enabled = enabled
[docs] async def execute(self, context: HookContext) -> Any: """Execute the hook callback.""" if not self.enabled: return None if asyncio.iscoroutinefunction(self.callback): return await self.callback(context) else: return self.callback(context)
[docs] class HookManager: """ Manages hooks for the Arshai framework. """
[docs] def __init__(self): self._hooks: Dict[HookType, List[Hook]] = { hook_type: [] for hook_type in HookType }
[docs] def register_hook(self, hook: Hook) -> None: """ Register a hook. Args: hook: The hook to register """ hooks_list = self._hooks[hook.hook_type] # Check for duplicate names if any(h.name == hook.name for h in hooks_list): raise ValueError(f"Hook '{hook.name}' already registered for {hook.hook_type}") # Add hook and sort by priority hooks_list.append(hook) hooks_list.sort(key=lambda h: h.priority, reverse=True)
[docs] def unregister_hook(self, name: str, hook_type: Optional[HookType] = None) -> None: """ Unregister a hook. Args: name: Name of the hook to unregister hook_type: Type of hook (if None, removes from all types) """ if hook_type: self._hooks[hook_type] = [ h for h in self._hooks[hook_type] if h.name != name ] else: for hook_list in self._hooks.values(): hook_list[:] = [h for h in hook_list if h.name != name]
[docs] async def execute_hooks( self, hook_type: HookType, data: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None ) -> List[Any]: """ Execute all hooks of a given type. Args: hook_type: Type of hooks to execute data: Data to pass to hooks metadata: Additional metadata Returns: List of results from hook executions """ context = HookContext( hook_type=hook_type, data=data, metadata=metadata or {} ) results = [] for hook in self._hooks[hook_type]: if hook.enabled: try: result = await hook.execute(context) results.append(result) # Allow hooks to modify context data if isinstance(result, dict) and "modified_data" in result: context.data.update(result["modified_data"]) except Exception as e: # Log error but don't stop execution print(f"Error in hook '{hook.name}': {e}") # In production, use proper logging return results
[docs] def get_hooks(self, hook_type: HookType) -> List[Hook]: """Get all hooks of a given type.""" return self._hooks[hook_type].copy()
[docs] def enable_hook(self, name: str) -> None: """Enable a hook by name.""" for hook_list in self._hooks.values(): for hook in hook_list: if hook.name == name: hook.enabled = True
[docs] def disable_hook(self, name: str) -> None: """Disable a hook by name.""" for hook_list in self._hooks.values(): for hook in hook_list: if hook.name == name: hook.enabled = False
# Global hook manager _global_hook_manager = HookManager()
[docs] def get_hook_manager() -> HookManager: """Get the global hook manager.""" return _global_hook_manager
# Decorator for easy hook registration
[docs] def hook(hook_type: HookType, name: Optional[str] = None, priority: int = 0): """ Decorator for registering a function as a hook. Example: @hook(HookType.BEFORE_AGENT_PROCESS, priority=10) def my_hook(context: HookContext): print(f"Processing: {context.data}") """ def decorator(func: Callable): hook_name = name or func.__name__ hook_instance = Hook( name=hook_name, hook_type=hook_type, callback=func, priority=priority ) get_hook_manager().register_hook(hook_instance) return func return decorator