Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify registration of error/output to status mappings #166

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions src/dispatch/status.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import enum
from typing import Any, Callable, Dict, Type
from typing import Any, Callable, Dict, Type, Union

from dispatch.sdk.v1 import status_pb2 as status_pb

Expand Down Expand Up @@ -78,16 +78,18 @@ def __str__(self):
Status.NOT_FOUND.__doc__ = "An operation was performed on a non-existent resource"
Status.NOT_FOUND._proto = status_pb.STATUS_NOT_FOUND

_ERROR_TYPES: Dict[Type[Exception], Callable[[Exception], Status]] = {}
_OUTPUT_TYPES: Dict[Type[Any], Callable[[Any], Status]] = {}
_ERROR_TYPES: Dict[Type[Exception], Union[Status, Callable[[Exception], Status]]] = {}
_OUTPUT_TYPES: Dict[Type[Any], Union[Status, Callable[[Any], Status]]] = {}


def status_for_error(error: BaseException) -> Status:
"""Returns a Status that corresponds to the specified error."""
# See if the error matches one of the registered types.
handler = _find_handler(error, _ERROR_TYPES)
if handler is not None:
return handler(error)
status_or_handler = _find_status_or_handler(error, _ERROR_TYPES)
if status_or_handler is not None:
if isinstance(status_or_handler, Status):
return status_or_handler
return status_or_handler(error)
# If not, resort to standard error categorization.
#
# See https://docs.python.org/3/library/exceptions.html
Expand Down Expand Up @@ -120,28 +122,43 @@ def status_for_error(error: BaseException) -> Status:
def status_for_output(output: Any) -> Status:
"""Returns a Status that corresponds to the specified output value."""
# See if the output value matches one of the registered types.
handler = _find_handler(output, _OUTPUT_TYPES)
if handler is not None:
return handler(output)
status_or_handler = _find_status_or_handler(output, _OUTPUT_TYPES)
if status_or_handler is not None:
if isinstance(status_or_handler, Status):
return status_or_handler
return status_or_handler(output)

return Status.OK


def register_error_type(
error_type: Type[Exception], handler: Callable[[Exception], Status]
error_type: Type[Exception],
status_or_handler: Union[Status, Callable[[Exception], Status]],
):
"""Register an error type, and a handler which derives a Status from
errors of this type."""
_ERROR_TYPES[error_type] = handler
"""Register an error type to Status mapping.

The caller can either register a base exception and a handler, which
derives a Status from errors of this type. Or, if there's only one
exception to Status mapping to register, the caller can simply pass
the exception class and the associated Status.
"""
_ERROR_TYPES[error_type] = status_or_handler


def register_output_type(output_type: Type[Any], handler: Callable[[Any], Status]):
"""Register an output type, and a handler which derives a Status from
outputs of this type."""
_OUTPUT_TYPES[output_type] = handler
def register_output_type(
output_type: Type[Any], status_or_handler: Union[Status, Callable[[Any], Status]]
):
"""Register an output type to Status mapping.

The caller can either register a base class and a handler, which
derives a Status from other classes of this type. Or, if there's
only one output class to Status mapping to register, the caller can
simply pass the class and the associated Status.
"""
_OUTPUT_TYPES[output_type] = status_or_handler


def _find_handler(obj, types):
def _find_status_or_handler(obj, types):
for cls in type(obj).__mro__:
try:
return types[cls]
Expand Down
109 changes: 106 additions & 3 deletions tests/dispatch/test_status.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import unittest
from typing import Any

from dispatch import error
from dispatch.integrations.http import http_response_code_status
from dispatch.status import Status, status_for_error
from dispatch.status import (
Status,
register_error_type,
register_output_type,
status_for_error,
status_for_output,
)


class TestErrorStatus(unittest.TestCase):
Expand Down Expand Up @@ -56,13 +63,49 @@ class CustomError(Exception):
pass

def handler(error: Exception) -> Status:
assert isinstance(error, CustomError)
return Status.OK

from dispatch.status import register_error_type

register_error_type(CustomError, handler)
assert status_for_error(CustomError()) is Status.OK

def test_status_for_custom_error_with_base_handler(self):
class CustomBaseError(Exception):
pass

class CustomError(CustomBaseError):
pass

def handler(error: Exception) -> Status:
assert isinstance(error, CustomBaseError)
assert isinstance(error, CustomError)
return Status.TCP_ERROR

register_error_type(CustomBaseError, handler)
assert status_for_error(CustomError()) is Status.TCP_ERROR

def test_status_for_custom_error_with_status(self):
class CustomError(Exception):
pass

register_error_type(CustomError, Status.THROTTLED)
assert status_for_error(CustomError()) is Status.THROTTLED

def test_status_for_custom_error_with_base_status(self):
class CustomBaseError(Exception):
pass

class CustomError(CustomBaseError):
pass

class CustomError2(CustomBaseError):
pass

register_error_type(CustomBaseError, Status.THROTTLED)
register_error_type(CustomError2, Status.INVALID_ARGUMENT)
assert status_for_error(CustomError()) is Status.THROTTLED
assert status_for_error(CustomError2()) is Status.INVALID_ARGUMENT

def test_status_for_custom_timeout(self):
class CustomError(TimeoutError):
pass
Expand Down Expand Up @@ -90,6 +133,66 @@ def test_status_for_DispatchError(self):
assert status_for_error(error.NotFoundError()) is Status.NOT_FOUND
assert status_for_error(error.DispatchError()) is Status.UNSPECIFIED

def test_status_for_custom_output(self):
class CustomOutput:
pass

assert status_for_output(CustomOutput()) is Status.OK # default

def test_status_for_custom_output_with_handler(self):
class CustomOutput:
pass

def handler(output: Any) -> Status:
assert isinstance(output, CustomOutput)
return Status.DNS_ERROR

register_output_type(CustomOutput, handler)
assert status_for_output(CustomOutput()) is Status.DNS_ERROR

def test_status_for_custom_output_with_base_handler(self):
class CustomOutputBase:
pass

class CustomOutputError(CustomOutputBase):
pass

class CustomOutputSuccess(CustomOutputBase):
pass

def handler(output: Any) -> Status:
assert isinstance(output, CustomOutputBase)
if isinstance(output, CustomOutputError):
return Status.DNS_ERROR
assert isinstance(output, CustomOutputSuccess)
return Status.OK

register_output_type(CustomOutputBase, handler)
assert status_for_output(CustomOutputSuccess()) is Status.OK
assert status_for_output(CustomOutputError()) is Status.DNS_ERROR

def test_status_for_custom_output_with_status(self):
class CustomOutputBase:
pass

class CustomOutputChild1(CustomOutputBase):
pass

class CustomOutputChild2(CustomOutputBase):
pass

register_output_type(CustomOutputBase, Status.PERMISSION_DENIED)
register_output_type(CustomOutputChild1, Status.TCP_ERROR)
assert status_for_output(CustomOutputChild1()) is Status.TCP_ERROR
assert status_for_output(CustomOutputChild2()) is Status.PERMISSION_DENIED

def test_status_for_custom_output_with_base_status(self):
class CustomOutput(Exception):
pass

register_output_type(CustomOutput, Status.THROTTLED)
assert status_for_output(CustomOutput()) is Status.THROTTLED


class TestHTTPStatusCodes(unittest.TestCase):
def test_http_response_code_status_400(self):
Expand Down