Source code for arshai.workflows.workflow_config
from typing import Dict, Any, List, Type, Optional
from arshai.core.interfaces.iworkflow import IWorkflowConfig, IWorkflowOrchestrator, INode
from arshai.workflows.workflow_orchestrator import BaseWorkflowOrchestrator
from arshai.utils import get_logger
[docs]
class WorkflowConfig(IWorkflowConfig):
    """Base implementation of workflow configuration.

    This implementation provides a foundation for workflow configuration where:
    - The config creates and configures the workflow orchestrator
    - The config defines the workflow structure (nodes and edges)
    - The config provides routing logic for input
    - Components are injected directly rather than through Settings

    Subclasses configure the orchestrator by overriding either
    ``_configure_workflow`` (sync) or ``_configure_workflow_async`` (async).
    The default async hook delegates to the sync hook. A subclass that only
    implements ``_configure_workflow`` can therefore also be configured
    through the async entry point.
    """

    def __init__(
        self,
        debug_mode: bool = False,
        **kwargs: Any
    ):
        """Initialize workflow configuration.

        Args:
            debug_mode: Whether to enable debug mode for verbose logging.
                Passed through to the orchestrator built by ``create_workflow``.
            **kwargs: Additional configuration options that subclasses can use.
                They are stored unchanged on ``self._kwargs``.

        Note:
            Subclasses should accept their required dependencies directly in their
            constructors rather than relying on a Settings object. This follows
            the three-layer architecture where developers have full control over
            component instantiation.

        Example:
            class MyWorkflowConfig(WorkflowConfig):
                def __init__(self, llm_client: ILLM, memory_manager: IMemoryManager, **kwargs):
                    super().__init__(**kwargs)
                    self.llm_client = llm_client
                    self.memory_manager = memory_manager
        """
        self.debug_mode = debug_mode
        # Extra options are kept verbatim so subclasses can read them later.
        self._kwargs: Dict[str, Any] = kwargs
        self._logger = get_logger(__name__)
        # The base class leaves these empty. Presumably subclasses fill them
        # while configuring the workflow (e.g. in _configure_workflow); verify
        # against the concrete subclasses.
        self.nodes: Dict[str, INode] = {}
        self.edges: Dict[str, str] = {}

    def create_workflow(self) -> IWorkflowOrchestrator:
        """Create a new, unconfigured workflow orchestrator.

        This method only builds the orchestrator. It does not add nodes, edges
        or entry points. Call ``_configure_workflow(workflow)`` or
        ``await _configure_workflow_async(workflow)`` separately to configure it.

        Returns:
            A fresh ``BaseWorkflowOrchestrator`` created with this config's
            ``debug_mode``.
        """
        self._logger.debug("Creating workflow orchestrator")
        workflow = BaseWorkflowOrchestrator(debug_mode=self.debug_mode)
        return workflow

    def _configure_workflow(self, workflow: IWorkflowOrchestrator) -> None:
        """Configure the workflow with nodes, edges, and entry points (sync).

        Subclasses override this hook (or ``_configure_workflow_async``) to define:
            1. What nodes the workflow contains
            2. How nodes are connected with edges
            3. Entry points and routing logic

        Args:
            workflow: The workflow orchestrator to configure.

        Raises:
            NotImplementedError: If the subclass does not override this hook.
        """
        raise NotImplementedError("Subclasses must implement _configure_workflow or _configure_workflow_async")

    async def _configure_workflow_async(self, workflow: IWorkflowOrchestrator) -> None:
        """Configure the workflow with nodes, edges, and entry points (async).

        Subclasses that need to await during setup override this hook to define:
            1. What nodes the workflow contains (async)
            2. How nodes are connected with edges
            3. Entry points and routing logic

        The default implementation delegates to the synchronous
        ``_configure_workflow``. Implementing either hook is then sufficient,
        as the error message promises.

        Args:
            workflow: The workflow orchestrator to configure.

        Raises:
            NotImplementedError: If neither ``_configure_workflow`` nor this
                method is overridden. The error is raised by the sync hook
                with the same message.
        """
        # Delegating keeps the "implement either hook" contract for async callers.
        self._configure_workflow(workflow)

    def _route_input(self, input_data: Dict[str, Any]) -> str:
        """Route to the appropriate entry node based on input.

        Subclasses must override this to decide which entry node the workflow
        starts from for a given input.

        Args:
            input_data: The input data to route.

        Returns:
            The name of the entry node to start with.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError("Subclasses must implement _route_input")

    def _create_nodes(self) -> Dict[str, INode]:
        """Create all nodes for the workflow.

        Subclasses must override this to build every node the workflow uses.

        Returns:
            Dictionary mapping node names to node instances.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError("Subclasses must implement _create_nodes")

    def _define_edges(self) -> Dict[str, str]:
        """Define the edges between nodes.

        Subclasses must override this to describe how nodes are connected.

        Returns:
            Dictionary mapping source node names to destination node names.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError("Subclasses must implement _define_edges")