diff --git a/README.md b/README.md index 7e43836..6d9d225 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ See more in https://github.com/indeedeng/iwf#what-is-iwf - [x] Start workflow API - [x] Executing `wait_until`/`execute` APIs and completing workflow - [x] Parallel execution of multiple states +- [x] GetWorkflowResultsWithWait API - [x] StateOption: WaitUntil(optional)/Execute API timeout and retry policy - [x] Get workflow with wait API - [x] Timer command @@ -45,7 +46,6 @@ See more in https://github.com/indeedeng/iwf#what-is-iwf - [x] Support execute API failure policy - [ ] Support workflow RPC - [ ] Signal command -- [ ] Signal workflow API ## Future -- the advanced features that already supported in server. Contributions are welcome to implement them in this SDK! - [ ] Atomic conditional complete workflow by checking signal/internal channel emptiness @@ -54,18 +54,18 @@ See more in https://github.com/indeedeng/iwf#what-is-iwf - [ ] Describe workflow API - [ ] TryGetWorkflowResults API - [ ] Consume N messages in a single command -- [ ] SearchAttribute: keyword -- [ ] New search attribute types: Double, Bool, Datetime, Keyword array, Text -- [ ] Workflow start options: initial search attributes -- [ ] Search workflow API - [ ] Reset workflow API - [ ] Skip timer API for testing/operation - [ ] Decider trigger type: any command combination - [ ] Failing workflow with results - [ ] Wait_until API failure policy - [ ] Caching on persistence -- [ ] Get workflow DataAttributes/SearchAttributes API - [ ] StateExecutionLocal +- [ ] SearchAttribute: keyword +- [ ] New search attribute types: Double, Bool, Datetime, Keyword array, Text +- [ ] Workflow start options: initial search attributes +- [ ] Search workflow API +- [ ] Get workflow DataAttributes/SearchAttributes API ### Running iwf-server locally diff --git a/iwf/client.py b/iwf/client.py index 89ff806..7884211 100644 --- a/iwf/client.py +++ b/iwf/client.py @@ -1,6 +1,8 @@ -from typing import Any, 
Optional, Type, TypeVar +import inspect +from typing import Any, Callable, Optional, Type, TypeVar from iwf.client_options import ClientOptions +from iwf.errors import InvalidArgumentError from iwf.registry import Registry from iwf.stop_workflow_options import StopWorkflowOptions from iwf.unregistered_client import UnregisteredClient, UnregisteredWorkflowOptions @@ -12,6 +14,14 @@ T = TypeVar("T") +def get_workflow_type_by_rpc_method(meth) -> str: + if inspect.ismethod(meth): + return inspect.getmro(meth.__self__.__class__)[0].__name__ + if inspect.isfunction(meth): + return meth.__qualname__.split(".", 1)[0].rsplit(".", 1)[0] + raise InvalidArgumentError(f"method {meth} is not a RPC method") + + class Client: def __init__(self, registry: Registry, options: Optional[ClientOptions] = None): self._registry = registry @@ -99,3 +109,25 @@ def stop_workflow( options: Optional[StopWorkflowOptions] = None, ): return self._unregistered_client.stop_workflow(workflow_id, "", options) + + def invoke_rpc( + self, + workflow_id: str, + rpc: Callable, # this can be a function: RPCWorkflow.rpc_method or a method: workflow_instance.rpc_method + input: Any = None, + return_type_hint: Optional[Type[T]] = None, + ) -> Optional[T]: + wf_type = get_workflow_type_by_rpc_method(rpc) + rpc_name = rpc.__name__ + rpc_info = self._registry.get_rpc_infos(wf_type)[rpc_name] + + return self._unregistered_client.invoke_rpc( + input=input, + workflow_id=workflow_id, + workflow_run_id="", + rpc_name=rpc_name, + timeout_seconds=rpc_info.timeout_seconds, + data_attribute_policy=rpc_info.data_attribute_loading_policy, + all_defined_search_attribute_types=[], + return_type_hint=return_type_hint, + ) diff --git a/iwf/communication.py b/iwf/communication.py index 2f8073e..bb6c960 100644 --- a/iwf/communication.py +++ b/iwf/communication.py @@ -1,15 +1,17 @@ -from typing import Any, Optional +from typing import Any, Optional, Union from iwf_api.models import EncodedObject, InterStateChannelPublishing 
from iwf.errors import WorkflowDefinitionError from iwf.object_encoder import ObjectEncoder +from iwf.state_movement import StateMovement class Communication: _type_store: dict[str, Optional[type]] _object_encoder: ObjectEncoder _to_publish_internal_channel: dict[str, list[EncodedObject]] + _state_movements: list[StateMovement] def __init__( self, type_store: dict[str, Optional[type]], object_encoder: ObjectEncoder @@ -17,8 +19,19 @@ def __init__( self._object_encoder = object_encoder self._type_store = type_store self._to_publish_internal_channel = {} + self._state_movements = [] - def publish_to_internal_channel(self, channel_name: str, value: Any): + def trigger_state_execution(self, state: Union[str, type], state_input: Any = None): + """ + + Args: + state: the workflowState TODO the type hint should be type[WorkflowState] + state_input: the input of the state + """ + movement = StateMovement.create(state, state_input) + self._state_movements.append(movement) + + def publish_to_internal_channel(self, channel_name: str, value: Any = None): if channel_name not in self._type_store: raise WorkflowDefinitionError( f"InternalChannel channel_name is not defined {channel_name}" @@ -45,3 +58,6 @@ def get_to_publishing_internal_channel(self) -> list[InterStateChannelPublishing for val in vals: pubs.append(InterStateChannelPublishing(name, val)) return pubs + + def get_to_trigger_state_movements(self) -> list[StateMovement]: + return self._state_movements diff --git a/iwf/errors.py b/iwf/errors.py index 423428b..3912d01 100644 --- a/iwf/errors.py +++ b/iwf/errors.py @@ -37,6 +37,14 @@ class WorkflowStillRunningError(ClientSideError): pass +class WorkflowRPCExecutionError(ClientSideError): + pass + + +class WorkflowRPCAcquiringLockFailure(ClientSideError): + pass + + class WorkflowAlreadyStartedError(ClientSideError): pass diff --git a/iwf/registry.py b/iwf/registry.py index 12ae53b..00ab857 100644 --- a/iwf/registry.py +++ b/iwf/registry.py @@ -1,8 +1,9 @@ -from typing 
import Optional +from typing import Callable, Optional from iwf.communication_schema import CommunicationMethodType from iwf.errors import InvalidArgumentError, WorkflowDefinitionError from iwf.persistence_schema import PersistenceFieldType +from iwf.rpc import RPCInfo from iwf.workflow import ObjectWorkflow, get_workflow_type from iwf.workflow_state import WorkflowState, get_state_id @@ -13,6 +14,7 @@ class Registry: _state_store: dict[str, dict[str, WorkflowState]] _internal_channel_type_store: dict[str, dict[str, Optional[type]]] _data_attribute_types: dict[str, dict[str, Optional[type]]] + _rpc_infos: dict[str, dict[str, RPCInfo]] def __init__(self): self._workflow_store = dict() @@ -20,12 +22,14 @@ def __init__(self): self._state_store = dict() self._internal_channel_type_store = dict() self._data_attribute_types = dict() + self._rpc_infos = dict() def add_workflow(self, wf: ObjectWorkflow): self._register_workflow_type(wf) self._register_workflow_state(wf) self._register_internal_channels(wf) self._register_data_attributes(wf) + self._register_workflow_rpcs(wf) def add_workflows(self, *wfs: ObjectWorkflow): for wf in wfs: @@ -62,6 +66,9 @@ def get_internal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]: def get_data_attribute_types(self, wf_type: str) -> dict[str, Optional[type]]: return self._data_attribute_types[wf_type] + def get_rpc_infos(self, wf_type: str) -> dict[str, RPCInfo]: + return self._rpc_infos[wf_type] + def _register_workflow_type(self, wf: ObjectWorkflow): wf_type = get_workflow_type(wf) if wf_type in self._workflow_store: @@ -104,3 +111,24 @@ def _register_workflow_state(self, wf): starting_state = state_def.state self._state_store[wf_type] = state_map self._starting_state_store[wf_type] = starting_state + + @staticmethod + def _is_decorated_by_rpc(func: Callable): + return getattr(func, "_is_iwf_rpc", False) + + @staticmethod + def _get_rpc_info(func: Callable): + info = getattr(func, "_rpc_info") + assert isinstance(info, 
RPCInfo) + # NOTE: we have to override the method here so that it's associated with the object + info.method_func = func + return info + + def _register_workflow_rpcs(self, wf): + wf_type = get_workflow_type(wf) + rpc_infos = {} + for method_name in dir(wf): + method = getattr(wf, method_name) + if callable(method) and self._is_decorated_by_rpc(method): + rpc_infos[method_name] = self._get_rpc_info(method) + self._rpc_infos[wf_type] = rpc_infos diff --git a/iwf/rpc.py b/iwf/rpc.py new file mode 100644 index 0000000..c561aff --- /dev/null +++ b/iwf/rpc.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from functools import wraps +from inspect import signature +from typing import Any, Callable, Optional + +from iwf_api.models import PersistenceLoadingPolicy, PersistenceLoadingType + +from iwf.errors import WorkflowDefinitionError + + +@dataclass +class RPCInfo: + method_func: Callable + timeout_seconds: int + input_type: Optional[type] = None + data_attribute_loading_policy: Optional[PersistenceLoadingPolicy] = None + params_order: Optional[ + list + ] = None # store this so that the rpc can be invoked with correct parameters + + +rpc_definition_err = WorkflowDefinitionError( + "an RPC must have at most 5 params: self, context:WorkflowContext, input:Any, persistence:Persistence, " + 'communication:Communication, where input can be any type as long as the param name is "input" ' +) + + +def rpc( + timeout_seconds: int = 10, + data_attribute_loading_policy: Optional[PersistenceLoadingPolicy] = None, +): + def decorator(func): + # preserve the properties of the original function. + @wraps(func) + def wrapper(*args, **kwargs): + # TODO need to add type hint for decorated method + return func(*args, **kwargs) + + wrapper._is_iwf_rpc = True + rpc_info = RPCInfo( + method_func=func, + timeout_seconds=timeout_seconds, + data_attribute_loading_policy=data_attribute_loading_policy, + ) + params = signature(func).parameters + + from inspect import _empty # ignored. 
+ from iwf.persistence import Persistence + from iwf.workflow_context import WorkflowContext + from iwf.communication import Communication + + valid_param_types = { + _empty: True, + Any: True, + Persistence: True, + WorkflowContext: True, + Communication: True, + } + need_persistence = False + params_order = [] + if len(params) > 5: + raise rpc_definition_err + + for k, v in params.items(): + if k != "self": + params_order.append(v.annotation) + if k == "input": + rpc_info.input_type = v.annotation + continue + if v.annotation == Persistence: + need_persistence = True + if v.annotation not in valid_param_types: + raise rpc_definition_err + if not need_persistence: + rpc_info.data_attribute_loading_policy = PersistenceLoadingPolicy( + persistence_loading_type=PersistenceLoadingType.LOAD_NONE + ) + rpc_info.params_order = params_order + wrapper._rpc_info = rpc_info + return wrapper + + return decorator diff --git a/iwf/tests/test_rpc.py b/iwf/tests/test_rpc.py new file mode 100644 index 0000000..9b7e67f --- /dev/null +++ b/iwf/tests/test_rpc.py @@ -0,0 +1,123 @@ +import inspect +import time +import unittest + +from iwf.client import Client +from iwf.command_request import CommandRequest, InternalChannelCommand +from iwf.command_results import CommandResults +from iwf.communication import Communication +from iwf.communication_schema import CommunicationMethod, CommunicationSchema +from iwf.persistence import Persistence +from iwf.persistence_schema import PersistenceField, PersistenceSchema +from iwf.rpc import rpc +from iwf.state_decision import StateDecision +from iwf.state_schema import StateSchema +from iwf.tests.worker_server import registry +from iwf.workflow import ObjectWorkflow +from iwf.workflow_context import WorkflowContext +from iwf.workflow_state import T, WorkflowState + +test_data_attribute = "test-1" +channel_name = "test-2" + + +class WaitState(WorkflowState[None]): + def wait_until( + self, + ctx: WorkflowContext, + input: T, + persistence: 
Persistence, + communication: Communication, + ) -> CommandRequest: + return CommandRequest.for_all_command_completed( + InternalChannelCommand.by_name(channel_name) + ) + + def execute( + self, + ctx: WorkflowContext, + input: T, + command_results: CommandResults, + persistence: Persistence, + communication: Communication, + ) -> StateDecision: + return StateDecision.single_next_state(EndState) + + +class EndState(WorkflowState[None]): + def execute( + self, + ctx: WorkflowContext, + input: T, + command_results: CommandResults, + persistence: Persistence, + communication: Communication, + ) -> StateDecision: + return StateDecision.graceful_complete_workflow("done") + + +class RPCWorkflow(ObjectWorkflow): + def get_persistence_schema(self) -> PersistenceSchema: + return PersistenceSchema.create( + PersistenceField.data_attribute_def(test_data_attribute, int) + ) + + def get_communication_schema(self) -> CommunicationSchema: + return CommunicationSchema.create( + CommunicationMethod.internal_channel_def(channel_name, type(None)) + ) + + def get_workflow_states(self) -> StateSchema: + return StateSchema.no_starting_state(WaitState(), EndState()) + + @rpc(timeout_seconds=100) + def test_simple_rpc(self): + return 123 + + @rpc() + def test_rpc_persistence_write( + self, + input: int, + persistence: Persistence, + ): + persistence.set_data_attribute(test_data_attribute, input) + + @rpc() + def test_rpc_persistence_read(self, pers: Persistence): + return pers.get_data_attribute(test_data_attribute) + + @rpc() + def test_rpc_trigger_state(self, com: Communication): + com.trigger_state_execution(WaitState) + + @rpc() + def test_rpc_publish_channel(self, com: Communication): + com.publish_to_internal_channel(channel_name) + + +class TestRPCs(unittest.TestCase): + @classmethod + def setUpClass(cls): + wf = RPCWorkflow() + registry.add_workflow(wf) + cls.client = Client(registry) + + def test_simple_rpc(self): + wf_id = 
f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}" + self.client.start_workflow(RPCWorkflow, wf_id, 10) + output = self.client.invoke_rpc(wf_id, RPCWorkflow.test_simple_rpc) + assert output == 123 + wf = RPCWorkflow() + output = self.client.invoke_rpc(wf_id, wf.test_simple_rpc) + assert output == 123 + + def test_complicated_rpc(self): + wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}" + self.client.start_workflow(RPCWorkflow, wf_id, 10) + self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_persistence_write, 100) + res = self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_persistence_read) + assert res == 100 + self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_trigger_state) + self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_publish_channel) + result = self.client.get_simple_workflow_result_with_wait(wf_id, str) + assert result == "done" diff --git a/iwf/tests/worker_server.py b/iwf/tests/worker_server.py index 99f8648..ec0b1e6 100644 --- a/iwf/tests/worker_server.py +++ b/iwf/tests/worker_server.py @@ -2,7 +2,11 @@ from threading import Thread from flask import Flask, request -from iwf_api.models import WorkflowStateExecuteRequest, WorkflowStateWaitUntilRequest +from iwf_api.models import ( + WorkflowStateExecuteRequest, + WorkflowStateWaitUntilRequest, + WorkflowWorkerRpcRequest, +) from iwf.registry import Registry from iwf.worker_service import ( @@ -37,9 +41,16 @@ def handle_execute(): return resp.to_dict() +@_flask_app.route(WorkerService.api_path_workflow_worker_rpc, methods=["POST"]) +def handle_rpc(): + req = WorkflowWorkerRpcRequest.from_dict(request.json) + resp = _worker_service.handle_workflow_worker_rpc(req) + return resp.to_dict() + + @_flask_app.errorhandler(Exception) def internal_error(exception): - # TODO: how to print to std in a different thread?? + # TODO: how to print to std ?? 
response = exception.get_response() # replace the body with JSON response.data = traceback.format_exc() diff --git a/iwf/unregistered_client.py b/iwf/unregistered_client.py index 826f776..1cd3a51 100644 --- a/iwf/unregistered_client.py +++ b/iwf/unregistered_client.py @@ -7,6 +7,7 @@ from iwf_api.api.default import ( post_api_v1_workflow_dataobjects_get, post_api_v1_workflow_reset, + post_api_v1_workflow_rpc, post_api_v1_workflow_search, post_api_v1_workflow_searchattributes_get, post_api_v1_workflow_signal, @@ -19,6 +20,7 @@ EncodedObject, ErrorResponse, IDReusePolicy, + PersistenceLoadingPolicy, SearchAttribute, SearchAttributeKeyAndType, WorkflowConfig, @@ -29,6 +31,8 @@ WorkflowGetSearchAttributesResponse, WorkflowResetRequest, WorkflowRetryPolicy, + WorkflowRpcRequest, + WorkflowRpcResponse, WorkflowSearchRequest, WorkflowSearchResponse, WorkflowSignalRequest, @@ -44,6 +48,8 @@ from iwf.client_options import ClientOptions from iwf.errors import ( WorkflowDefinitionError, + WorkflowRPCAcquiringLockFailure, + WorkflowRPCExecutionError, WorkflowStillRunningError, parse_unexpected_error, process_http_error, @@ -195,6 +201,51 @@ def get_simple_workflow_result_with_wait( type_hint, ) + def invoke_rpc( + self, + input: Any, + workflow_id: str, + workflow_run_id: str, + rpc_name: str, + timeout_seconds: int, + data_attribute_policy: Optional[PersistenceLoadingPolicy], + all_defined_search_attribute_types: list[SearchAttributeKeyAndType], + return_type_hint: Optional[Type[T]] = None, + ) -> Optional[T]: + request = WorkflowRpcRequest( + input_=self.client_options.object_encoder.encode(input), + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + rpc_name=rpc_name, + timeout_seconds=timeout_seconds, + search_attributes=all_defined_search_attribute_types, + ) + if data_attribute_policy is not None: + request.data_attributes_loading_policy = data_attribute_policy + + try: + response = post_api_v1_workflow_rpc.sync_detailed( + client=self.api_client, + 
json_body=request, + ) + except errors.UnexpectedStatus as err: + err_resp = parse_unexpected_error(err) + if err.status_code == 420: + raise WorkflowRPCExecutionError(err.status_code, err_resp) + if err.status_code == 450: + raise WorkflowRPCAcquiringLockFailure(err.status_code, err_resp) + else: + raise RuntimeError(f"unknown error code {err.status_code}") + + if response.status_code != http.HTTPStatus.OK: + assert isinstance(response.parsed, ErrorResponse) + raise process_http_error(response.status_code, response.parsed) + assert isinstance(response.parsed, WorkflowRpcResponse) + return self.client_options.object_encoder.decode( + response.parsed.output, + return_type_hint, + ) + def signal_workflow( self, workflow_id: str, diff --git a/iwf/worker_service.py b/iwf/worker_service.py index b948656..19675a8 100644 --- a/iwf/worker_service.py +++ b/iwf/worker_service.py @@ -8,6 +8,8 @@ WorkflowStateExecuteResponse, WorkflowStateWaitUntilRequest, WorkflowStateWaitUntilResponse, + WorkflowWorkerRpcRequest, + WorkflowWorkerRpcResponse, ) from iwf_api.types import Unset @@ -17,9 +19,9 @@ from iwf.object_encoder import ObjectEncoder from iwf.persistence import Persistence from iwf.registry import Registry -from iwf.state_decision import _to_idl_state_decision +from iwf.state_decision import StateDecision, _to_idl_state_decision from iwf.utils.iwf_typing import assert_not_unset, unset_to_none -from iwf.workflow_context import _from_idl_context +from iwf.workflow_context import WorkflowContext, _from_idl_context from iwf.workflow_state import get_input_type @@ -38,6 +40,7 @@ class WorkerService: api_path_workflow_state_execute: typing.ClassVar[ str ] = "/api/v1/workflowState/decide" + api_path_workflow_worker_rpc: typing.ClassVar[str] = "/api/v1/workflowWorker/rpc" def __init__( self, registry: Registry, options: WorkerOptions = default_worker_options @@ -45,6 +48,71 @@ def __init__( self._registry = registry self._options = options + def handle_workflow_worker_rpc( + 
self, + request: WorkflowWorkerRpcRequest, + ) -> WorkflowWorkerRpcResponse: + wf_type = request.workflow_type + rpc_info = self._registry.get_rpc_infos(wf_type)[request.rpc_name] + + internal_channel_types = self._registry.get_internal_channel_types(wf_type) + data_attributes_types = self._registry.get_data_attribute_types(wf_type) + + context = _from_idl_context(request.context) + _input = self._options.object_encoder.decode( + unset_to_none(request.input_), rpc_info.input_type + ) + + current_data_attributes: dict[str, typing.Union[EncodedObject, None]] = {} + if not isinstance(request.data_attributes, Unset): + current_data_attributes = { + assert_not_unset(attr.key): unset_to_none(attr.value) + for attr in request.data_attributes + } + + persistence = Persistence( + data_attributes_types, self._options.object_encoder, current_data_attributes + ) + communication = Communication( + internal_channel_types, self._options.object_encoder + ) + params: typing.Any = [] + if rpc_info.params_order is not None: + for param_type in rpc_info.params_order: + if param_type == Persistence: + params.append(persistence) + elif param_type == Communication: + params.append(communication) + elif param_type == WorkflowContext: + params.append(context) + else: + params.append(_input) + + output = rpc_info.method_func(*params) + + pubs = communication.get_to_publishing_internal_channel() + response = WorkflowWorkerRpcResponse( + output=self._options.object_encoder.encode(output) + ) + + if len(pubs) > 0: + response.publish_to_inter_state_channel = pubs + if len(persistence.get_updated_values_to_return()) > 0: + response.upsert_data_attributes = [ + KeyValue(k, v) + for (k, v) in persistence.get_updated_values_to_return().items() + ] + if len(communication.get_to_trigger_state_movements()) > 0: + movements = communication.get_to_trigger_state_movements() + decision = StateDecision.multi_next_states(*movements) + response.state_decision = _to_idl_state_decision( + decision, + wf_type, 
+ self._registry, + self._options.object_encoder, + ) + return response + def handle_workflow_state_wait_until( self, request: WorkflowStateWaitUntilRequest,