diff --git a/.github/workflows/lint-code.yml b/.github/workflows/lint-code.yml
index a4861559d..17b267d05 100644
--- a/.github/workflows/lint-code.yml
+++ b/.github/workflows/lint-code.yml
@@ -35,7 +35,7 @@ jobs:
           python -m pip install -r setup_requirements.txt
       - name: Check Formatting
         run: tox -e fmt
-      - name: Run pylint
+      - name: Linting
         run: tox -e lint
       - name: Setup Graphviz # `graphviz/dot` is required for the import checker
         uses: ts-graphviz/setup-graphviz@v1
diff --git a/.github/workflows/publish-library.yml b/.github/workflows/publish-library.yml
index 9574feaf5..5e7d3957a 100644
--- a/.github/workflows/publish-library.yml
+++ b/.github/workflows/publish-library.yml
@@ -27,6 +27,7 @@ jobs:
         uses: actions/setup-python@v3
       - name: Build and check package
         run: |
+          pip install tox
           tox -e build,twinecheck
       - name: Upload package
         if: github.event_name == 'release'
diff --git a/.gitignore b/.gitignore
index d3b9755d9..cbafc1609 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,3 +35,6 @@ _build/
 
 /caikit/_version.py
 
+# Compiled pb2s
+*_pb2.py
+*_pb2_grpc.py
diff --git a/.pylintrc b/.pylintrc
deleted file mode 100644
index 5c7b4676e..000000000
--- a/.pylintrc
+++ /dev/null
@@ -1,646 +0,0 @@
-[MAIN]
-
-# Analyse import fallback blocks. This can be used to support both Python 2 and
-# 3 compatible code, which means that the block might have code that exists
-# only in one or another interpreter, leading to false positives when analysed.
-analyse-fallback-blocks=no
-
-# Clear in-memory caches upon conclusion of linting. Useful if running pylint
-# in a server-like mode.
-clear-cache-post-run=no
-
-# Load and enable all available extensions. Use --list-extensions to see a list
-# all available extensions.
-#enable-all-extensions=
-
-# In error mode, messages with a category besides ERROR or FATAL are
-# suppressed, and no reports are done by default. Error mode is compatible with
-# disabling specific errors.
-#errors-only=
-
-# Always return a 0 (non-error) status code, even if lint errors are found.
-# This is primarily useful in continuous integration scripts.
-#exit-zero=
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code.
-extension-pkg-allow-list=
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
-# for backward compatibility.)
-extension-pkg-whitelist=
-
-# Return non-zero exit code if any of these messages/categories are detected,
-# even if score is above --fail-under value. Syntax same as enable. Messages
-# specified are enabled, while categories only check already-enabled messages.
-fail-on=
-
-# Specify a score threshold under which the program will exit with error.
-fail-under=10
-
-# Interpret the stdin as a python script, whose filename needs to be passed as
-# the module_or_package argument.
-#from-stdin=
-
-# Files or directories to be skipped. They should be base names, not paths.
-ignore=CVS,protobufs
-
-# Add files or directories matching the regular expressions patterns to the
-# ignore-list. The regex matches against paths and can be in Posix or Windows
-# format. Because '\\' represents the directory delimiter on Windows systems,
-# it can't be used as an escape character.
-ignore-paths=
-
-# Files or directories matching the regular expression patterns are skipped.
-# The regex matches against base names, not paths. The default value ignores
-# Emacs file locks
-ignore-patterns=^\.#
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis). It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
-# number of processors available to use, and will cap the count on Windows to
-# avoid hangs.
-jobs=1
-
-# Control the amount of potential inferred values when inferring a single
-# object. This can help the performance when dealing with large functions or
-# complex, nested conditions.
-limit-inference-results=100
-
-# List of plugins (as comma separated values of python module names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Pickle collected data for later comparisons.
-persistent=yes
-
-# Minimum Python version to use for version dependent checks. Will default to
-# the version used to run pylint.
-py-version=3.9
-
-# Discover python modules and packages in the file system subtree.
-recursive=no
-
-# When enabled, pylint would attempt to guess common misconfiguration and emit
-# user-friendly hints instead of false-positive error messages.
-suggestion-mode=yes
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-# In verbose mode, extra non-checker-related info will be displayed.
-#verbose=
-
-
-[BASIC]
-
-# Naming style matching correct argument names.
-argument-naming-style=snake_case
-
-# Regular expression matching correct argument names. Overrides argument-
-# naming-style. If left empty, argument names will be checked with the set
-# naming style.
-#argument-rgx=
-
-# Naming style matching correct attribute names.
-attr-naming-style=snake_case
-
-# Regular expression matching correct attribute names. Overrides attr-naming-
-# style. If left empty, attribute names will be checked with the set naming
-# style.
-#attr-rgx=
-
-# Bad variable names which should always be refused, separated by a comma.
-bad-names=foo,
-          bar,
-          baz,
-          toto,
-          tutu,
-          tata
-
-# Bad variable names regexes, separated by a comma. If names match any regex,
-# they will always be refused
-bad-names-rgxs=
-
-# Naming style matching correct class attribute names.
-class-attribute-naming-style=any
-
-# Regular expression matching correct class attribute names. Overrides class-
-# attribute-naming-style. If left empty, class attribute names will be checked
-# with the set naming style.
-#class-attribute-rgx=
-
-# Naming style matching correct class constant names.
-class-const-naming-style=UPPER_CASE
-
-# Regular expression matching correct class constant names. Overrides class-
-# const-naming-style. If left empty, class constant names will be checked with
-# the set naming style.
-#class-const-rgx=
-
-# Naming style matching correct class names.
-class-naming-style=PascalCase
-
-# Regular expression matching correct class names. Overrides class-naming-
-# style. If left empty, class names will be checked with the set naming style.
-#class-rgx=
-
-# Naming style matching correct constant names.
-const-naming-style=UPPER_CASE
-
-# Regular expression matching correct constant names. Overrides const-naming-
-# style. If left empty, constant names will be checked with the set naming
-# style.
-#const-rgx=
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=-1
-
-# Naming style matching correct function names.
-function-naming-style=snake_case
-
-# Regular expression matching correct function names. Overrides function-
-# naming-style. If left empty, function names will be checked with the set
-# naming style.
-#function-rgx=
-
-# Good variable names which should always be accepted, separated by a comma.
-good-names=i,
-           j,
-           k,
-           ex,
-           Run,
-           _
-
-# Good variable names regexes, separated by a comma. If names match any regex,
-# they will always be accepted
-good-names-rgxs=
-
-# Include a hint for the correct naming format with invalid-name.
-include-naming-hint=no
-
-# Naming style matching correct inline iteration names.
-inlinevar-naming-style=any
-
-# Regular expression matching correct inline iteration names. Overrides
-# inlinevar-naming-style. If left empty, inline iteration names will be checked
-# with the set naming style.
-#inlinevar-rgx=
-
-# Naming style matching correct method names.
-method-naming-style=snake_case
-
-# Regular expression matching correct method names. Overrides method-naming-
-# style. If left empty, method names will be checked with the set naming style.
-#method-rgx=
-
-# Naming style matching correct module names.
-module-naming-style=snake_case
-
-# Regular expression matching correct module names. Overrides module-naming-
-# style. If left empty, module names will be checked with the set naming style.
-#module-rgx=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=^_
-
-# List of decorators that produce properties, such as abc.abstractproperty. Add
-# to this list to register other decorators that produce valid properties.
-# These decorators are taken in consideration only for invalid-name.
-property-classes=abc.abstractproperty
-
-# Regular expression matching correct type variable names. If left empty, type
-# variable names will be checked with the set naming style.
-#typevar-rgx=
-
-# Naming style matching correct variable names.
-variable-naming-style=snake_case
-
-# Regular expression matching correct variable names. Overrides variable-
-# naming-style. If left empty, variable names will be checked with the set
-# naming style.
-#variable-rgx=
-
-
-[CLASSES]
-
-# Warn about protected attribute access inside special methods
-check-protected-access-in-special-methods=no
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,
-                      __new__,
-                      setUp,
-                      __post_init__
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,
-                  _fields,
-                  _replace,
-                  _source,
-                  _make
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=mcs
-
-
-[DESIGN]
-
-# List of regular expressions of class ancestor names to ignore when counting
-# public methods (see R0903)
-exclude-too-few-public-methods=
-
-# List of qualified class names to ignore when counting class parents (see
-# R0901)
-ignored-parents=
-
-# Maximum number of arguments for function / method.
-max-args=5
-
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-
-# Maximum number of boolean expressions in an if statement (see R0916).
-max-bool-expr=5
-
-# Maximum number of branch for function / method body.
-max-branches=12
-
-# Maximum number of locals for function / method body.
-max-locals=15
-
-# Maximum number of parents for a class (see R0901).
-max-parents=7
-
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-
-# Maximum number of return / yield for function / method body.
-max-returns=6
-
-# Maximum number of statements in function / method body.
-max-statements=50
-
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when caught.
-overgeneral-exceptions=builtins.BaseException,builtins.Exception
-
-
-[FORMAT]
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=^\s*(# )?<?https?://\S+>?$
-
-# Number of spaces of indent required inside a hanging or continued line.
-indent-after-paren=4
-
-# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
-# tab).
-indent-string='    '
-
-# Maximum number of characters on a single line.
-max-line-length=100
-
-# Maximum number of lines in a module.
-max-module-lines=1100
-
-# Allow the body of a class to be on the same line as the declaration if body
-# contains single statement.
-single-line-class-stmt=no
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=no
-
-
-[IMPORTS]
-
-# List of modules that can be imported at any level, not just the top level
-# one.
-allow-any-import-level=
-
-# Allow explicit reexports by alias from a package __init__.
-allow-reexport-from-package=no
-
-# Allow wildcard imports from modules that define __all__.
-allow-wildcard-with-all=no
-
-# Deprecated modules which should not be used, separated by a comma.
-deprecated-modules=
-
-# Output a graph (.gv or any supported image format) of external dependencies
-# to the given file (report RP0402 must not be disabled).
-ext-import-graph=
-
-# Output a graph (.gv or any supported image format) of all (i.e. internal and
-# external) dependencies to the given file (report RP0402 must not be
-# disabled).
-import-graph=
-
-# Output a graph (.gv or any supported image format) of internal dependencies
-# to the given file (report RP0402 must not be disabled).
-int-import-graph=
-
-# Force import order to recognize a module as part of the standard
-# compatibility libraries.
-known-standard-library=
-
-# Force import order to recognize a module as part of a third party library.
-known-third-party=enchant
-
-# Couples of modules and preferred modules, separated by a comma.
-preferred-modules=
-
-
-[LOGGING]
-
-# The type of string formatting that logging methods do. `old` means using %
-# formatting, `new` is for `{}` formatting.
-logging-format-style=old
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format.
-logging-modules=logging
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
-# UNDEFINED.
-confidence=HIGH,
-           CONTROL_FLOW,
-           INFERENCE,
-           INFERENCE_FAILURE,
-           UNDEFINED
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once). You can also use "--disable=all" to
-# disable everything first and then re-enable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use "--disable=all --enable=classes
-# --disable=W".
-disable=raw-checker-failed,
-        bad-inline-option,
-        locally-disabled,
-        file-ignored,
-        suppressed-message,
-        useless-suppression,
-        deprecated-pragma,
-        # Added messages
-        use-symbolic-message-instead,
-        invalid-name,
-        missing-class-docstring,
-        missing-module-docstring,
-        missing-function-docstring,
-        consider-using-f-string,
-        inconsistent-return-statements,
-        no-member,
-        too-many-arguments,
-        too-many-locals,
-        too-many-branches,
-        too-many-statements,
-        cyclic-import,
-        too-few-public-methods,
-        protected-access,
-        fixme,
-        logging-format-interpolation,
-        logging-too-many-args,
-        attribute-defined-outside-init,
-        abstract-method,
-        pointless-statement,
-        wrong-import-order
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time (only on the command line, not in the configuration file where
-# it should appear only once). See also the "--disable" option for examples.
-enable=c-extension-no-member
-
-
-[METHOD_ARGS]
-
-# List of qualified names (i.e., library.method) which require a timeout
-# parameter e.g. 'requests.api.get,requests.api.post'
-timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=FIXME,
-      XXX,
-      TODO
-
-# Regular expression of note tags to take in consideration.
-notes-rgx=
-
-
-[REFACTORING]
-
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=5
-
-# Complete name of functions that never returns. When checking for
-# inconsistent-return-statements if a never returning function is called then
-# it will be considered as an explicit return statement and no message will be
-# printed.
-never-returning-functions=sys.exit,argparse.parse_error
-
-
-[REPORTS]
-
-# Python expression which should return a score less than or equal to 10. You
-# have access to the variables 'fatal', 'error', 'warning', 'refactor',
-# 'convention', and 'info' which contain the number of messages in each
-# category, as well as 'statement' which is the total number of statements
-# analyzed. This score is used by the global evaluation report (RP0004).
-evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details.
-msg-template=
-
-# Set the output format. Available formats are text, parseable, colorized, json
-# and msvs (visual studio). You can also give a reporter class, e.g.
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Tells whether to display a full report or only the messages.
-reports=yes
-
-# Activate the evaluation score.
-score=yes
-
-
-[SIMILARITIES]
-
-# Comments are removed from the similarity computation
-ignore-comments=yes
-
-# Docstrings are removed from the similarity computation
-ignore-docstrings=yes
-
-# Imports are removed from the similarity computation
-ignore-imports=yes
-
-# Signatures are removed from the similarity computation
-ignore-signatures=yes
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-
-[SPELLING]
-
-# Limits count of emitted suggestions for spelling mistakes.
-max-spelling-suggestions=4
-
-# Spelling dictionary name. Available dictionaries: none. To make it work,
-# install the 'python-enchant' package.
-spelling-dict=
-
-# List of comma separated words that should be considered directives if they
-# appear at the beginning of a comment and should not be checked.
-spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# A path to a file that contains the private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words to the private dictionary (see the
-# --spelling-private-dict-file option) instead of raising a message.
-spelling-store-unknown-words=no
-
-
-[STRING]
-
-# This flag controls whether inconsistent-quotes generates a warning when the
-# character used as a quote delimiter is used inconsistently within a module.
-check-quote-consistency=no
-
-# This flag controls whether the implicit-str-concat should generate a warning
-# on implicit string concatenation in sequences defined over several lines.
-check-str-concat-over-line-jumps=no
-
-
-[TYPECHECK]
-
-# List of decorators that produce context managers, such as
-# contextlib.contextmanager. Add to this list to register other decorators that
-# produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E1101 when accessed. Python regular
-# expressions are accepted.
-generated-members=
-
-# Tells whether to warn about missing members when the owner of the attribute
-# is inferred to be None.
-ignore-none=yes
-
-# This flag controls whether pylint should warn about no-member and similar
-# checks whenever an opaque object is returned when inferring. The inference
-# can return multiple potential results while evaluating a Python object, but
-# some branches might not be evaluated, which results in partial inference. In
-# that case, it might be useful to still emit no-member and other checks for
-# the rest of the inferred objects.
-ignore-on-opaque-inference=yes
-
-# List of symbolic message names to ignore for Mixin members.
-ignored-checks-for-mixins=no-member,
-                          not-async-context-manager,
-                          not-context-manager,
-                          attribute-defined-outside-init
-
-# List of class names for which member attributes should not be checked (useful
-# for classes with dynamically set attributes). This supports the use of
-# qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
-
-# Show a hint with possible names when a member name was not found. The aspect
-# of finding the hint is based on edit distance.
-missing-member-hint=yes
-
-# The minimum edit distance a name should have in order to be considered a
-# similar match for a missing member name.
-missing-member-hint-distance=1
-
-# The total number of similar names that should be taken in consideration when
-# showing a hint for a missing member.
-missing-member-max-choices=1
-
-# Regex pattern to define which classes are considered mixins.
-mixin-class-rgx=.*[Mm]ixin
-
-# List of decorators that change the signature of a decorated function.
-signature-mutators=
-
-
-[VARIABLES]
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid defining new builtins when possible.
-additional-builtins=
-
-# Tells whether unused global variables should be treated as a violation.
-allow-global-unused-variables=yes
-
-# List of names allowed to shadow builtins
-allowed-redefined-builtins=
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,
-          _cb
-
-# A regular expression matching the name of dummy variables (i.e. expected to
-# not be used).
-dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
-
-# Argument names that match this expression will be ignored.
-ignored-argument-names=_.*|^ignored_|^unused_
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
diff --git a/caikit/config/config.yml b/caikit/config/config.yml
index 8dbe055cd..434e1f038 100644
--- a/caikit/config/config.yml
+++ b/caikit/config/config.yml
@@ -87,10 +87,15 @@ data_streams:
 
 ### Runtime configurations
 runtime:
-  # The runtime library (or libraries) whose models we want to serve using Caikit Runtime. This should
-  # be a snake case string, e.g., caikit_nlp or caikit_cv.
-  library: sample_lib # TODO: replace with libraries below when runtime can support multiple libraries
+  # The runtime library (or libraries) whose models we want to serve using
+  # Caikit Runtime. This should be a snake case string, e.g., caikit_nlp or
+  # caikit_cv.
+  # TODO: replace with libraries below when runtime can support multiple libraries
+  library: sample_lib
   local_models_dir: models
+  # When loading models at boot, wait for the loads to complete before booting
+  # the server
+  wait_for_initial_model_loads: true
 
   # If enabled, the models in local_models_dir will be periodically sync'ed
   # with the in-memory models. New models that are not in-memory that are
@@ -103,10 +108,17 @@ runtime:
   # that fail due to partially uploaded model artifacts when the load is
   # initiated.
   lazy_load_retries: 2
+  # Amount of time to watch for file updates to detect if a model is being
+  # written. Models that are being written will not be prematurely loaded. Use
+  # zero or a negative value or None to disable.
+  lazy_load_write_detection_period_seconds: 0.05
 
   # Number of threads to make available for load jobs (null => all)
   load_threads: null
 
+  # Number of threads for both the http and grpc server to share for servicing requests
+  server_thread_pool_size: 100
+
   # TLS configs
   tls:
     server:
@@ -119,8 +131,6 @@ runtime:
   grpc:
     enabled: true
     port: 8085
-    # Number of workers with which we will run the gRPC server
-    server_thread_pool_size: 5
     # gRPC Server shutdown grace period
     # the server shuts down immediately when stopped,
     # and all RPCs active at the end of the grace period are aborted
@@ -130,6 +140,10 @@ runtime:
     # Additional server options as key/value pairs
     # CITE: https://github.com/grpc/grpc/blob/master/include/grpc/impl/channel_arg_names.h#L22
     options: {}
+    # Legacy config for setting thread pool size. See runtime.server_thread_pool_size instead
+    server_thread_pool_size: null
+    # Timeout for health probe to receive a response
+    probe_timeout: null
 
   # Configuration for the http server
   http:
@@ -139,6 +153,11 @@ runtime:
     route_prefix: api/v1
     # Maximum number of seconds to wait for graceful shutdown
     server_shutdown_grace_period_seconds: 5
+    # Timeout for health probe to receive a response
+    probe_timeout: 0.01
+    # Additional uvicorn server configuration
+    # CITE: https://github.com/encode/uvicorn/blob/master/uvicorn/config.py#L188
+    server_config: {}
 
   # Configuration for the metrics server
   metrics:
@@ -195,6 +214,12 @@ runtime:
     output_dir: training_output
     save_with_id: true
 
+  # Version details to retrieve
+  version_info:
+    python_packages:
+      all: false
+    runtime_image: ""
+
 inference_plugin:
   model_mesh:
     # Model Mesh specific versions
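
Note on the config changes above: the gRPC-specific `server_thread_pool_size` is now a legacy key, superseded by the shared `runtime.server_thread_pool_size`. A minimal sketch (not part of the patch) of how consuming code might reconcile the two locations; the fallback order is an assumption based on the comments, and `get_config` is caikit's standard config accessor:

```python
from caikit.config import get_config

cfg = get_config()

# Prefer the legacy per-gRPC key if a deployment still sets it, otherwise
# fall back to the new shared pool size (defaults to 100 per config.yml)
pool_size = (
    cfg.runtime.grpc.server_thread_pool_size  # legacy location, null by default
    or cfg.runtime.server_thread_pool_size    # new shared location
)
```
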
diff --git a/caikit/core/__init__.py b/caikit/core/__init__.py
index b49200af8..98025c8ec 100644
--- a/caikit/core/__init__.py
+++ b/caikit/core/__init__.py
@@ -18,7 +18,6 @@
 # the import order cannot adhere to the linter here because we must do things like
 # disable warnings, initialize the JVM and configure logging in a specific order
-# pylint: disable=wrong-import-position,wrong-import-order
 
 # NOTE: There are cyclic imports due to the "import *"s here, when modules then
 # "import core"
@@ -39,7 +38,7 @@ from .toolkit import *
 
 # Configure the global model wrangling functions
-MODEL_MANAGER = ModelManager()
+MODEL_MANAGER = ModelManager()  # noqa: F405
 extract = MODEL_MANAGER.extract
 load = MODEL_MANAGER.load
 resolve_and_load = MODEL_MANAGER.resolve_and_load
diff --git a/caikit/core/data_model/base.py b/caikit/core/data_model/base.py
index 5b169c8a7..d5ddd472d 100644
--- a/caikit/core/data_model/base.py
+++ b/caikit/core/data_model/base.py
@@ -19,7 +19,18 @@
 from dataclasses import dataclass
 from enum import Enum
 from io import IOBase
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    ClassVar,
+    Dict,
+    List,
+    NoReturn,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 import base64
 import datetime
 import json
@@ -37,16 +48,36 @@
 from ..exceptions import error_handler
 from . import enums, json_dict, timestamp
 
+# if TYPE_CHECKING: # TODO: uncommenting this breaks `tox -e imports` because of a circular import
+#     # Local
+#     from caikit.core.data_model.data_backends import DataModelBackendBase
+#     from caikit.interfaces.common.data_model.file import File
+
 # metaclass-generated field members cannot be detected by pylint
 # pylint: disable=no-member
 
 # pylint: disable=too-many-lines
 
 log = alog.use_channel("DATAM")
-error = error_handler.get(log)
+error: Callable[..., NoReturn] = error_handler.get(log)
 
 
 class _DataBaseMetaClass(type):
+    fields: Tuple
+    full_name: str
+    fields_enum_map: Dict  # {}
+    fields_enum_rev: Dict  # {}
+    _fields_oneofs_map: Dict  # {}
+    _fields_to_oneof: Dict  # {}
+    _fields_map: Tuple  # ()
+    _fields_message: Tuple  # ()
+    _fields_message_repeated: Tuple  # ()
+    _fields_enum: Tuple  # ()
+    _fields_enum_repeated: Tuple  # ()
+    _fields_primitive: Tuple  # ()
+    _fields_primitive_repeated: Tuple  # ()
+    _proto_class: ClassVar[Type[ProtoMessageType]]
+
     """Meta class for all structures in the data model."""
 
     # store a registry of all classes that use this metaclass, i.e.,
@@ -136,7 +167,7 @@ def __new__(mcs, name, bases, attrs):
             # Otherwise, we need to get the fields from a "special" attribute
             else:
                 fields = attrs.pop(mcs._FWD_DECL_FIELDS, None)
-                log.debug4(
+                log.debug4(  # type: ignore
                     "Using dataclass forward declaration fields %s for %s", fields, name
                 )
                 error.value_check(
@@ -170,7 +201,7 @@ def __new__(mcs, name, bases, attrs):
         return instance
 
     @classmethod
-    def parse_proto_descriptor(mcs, cls):
+    def parse_proto_descriptor(mcs, cls):  # pyright: ignore[reportSelfClsParameterName]
        """Encapsulate the logic for parsing the protobuf descriptor here. This
        allows the parsing to be done as a post-process after metaclass
        initialization
@@ -304,17 +335,17 @@ def parse_proto_descriptor(mcs, cls):
         # If there is not already an __init__ function defined, make one
         current_init = cls.__init__
         if current_init is None or current_init is DataBase.__init__:
-            setattr(cls, "__init__", mcs._make_init(cls.fields))
+            cls.__init__ = mcs._make_init(cls.fields)
 
         # Check DataBase for file handlers
-        setattr(
-            cls,
-            "supports_file_operations",
-            cls.to_file != DataBase.to_file and cls.from_file != DataBase.from_file,
+        cls.supports_file_operations = (
+            cls.to_file != DataBase.to_file and cls.from_file != DataBase.from_file
         )
 
     @classmethod
-    def _make_property_getter(mcs, field, oneof_name=None):
+    def _make_property_getter(
+        mcs, field, oneof_name=None  # pyright: ignore[reportSelfClsParameterName]
+    ):
        """This helper creates an @property attribute getter for the given field
 
        NOTE: This needs to live as a standalone function in order for the given
@@ -340,7 +371,7 @@ def _property_getter(self):
                 )
                 attr_val = backend.get_attribute(self.__class__, field)
                 if isinstance(attr_val, self.__class__.OneofFieldVal):
-                    log.debug2("Got a OneofFieldVal from the backend")
+                    log.debug2("Got a OneofFieldVal from the backend")  # type: ignore
                     assert field in self.__class__._fields_oneofs_map
                     self._get_which_oneof_dict()[field] = attr_val.which_oneof
                     attr_val = attr_val.val
@@ -435,11 +466,13 @@ def __init__(self, *args, **kwargs):
                 setattr(self, field_name, None)
 
         # Set docstring to the method explicitly
-        setattr(__init__, "__doc__", docstring)
+        __init__.__doc__ = docstring
 
         return __init__
 
     @classmethod
-    def _sorted_oneof_field_names(mcs, oneof: OneofDescriptor) -> List[str]:
+    def _sorted_oneof_field_names(
+        mcs, oneof: OneofDescriptor  # pyright: ignore[reportSelfClsParameterName]
+    ) -> List[str]:
        """Helper to get the list of oneof fields while ensuring field names
        are sorted such that bool < int < float. This ensures that when
        iterating fields for which_oneof inference, lower-precedence types take
@@ -529,7 +562,7 @@ def from_backend(cls, backend):
         return instance
 
     @property
-    def backend(self) -> Optional["DataModelBackendBase"]:
+    def backend(self) -> Optional["DataModelBackendBase"]:  # type: ignore # noqa: F821 # see TYPE_CHECKING note at the top
         return getattr(self, _DataBaseMetaClass._BACKEND_ATTR, None)
 
     def which_oneof(self, oneof_name: str) -> Optional[str]:
@@ -606,7 +639,7 @@ def _is_valid_type_for_field(cls, field_name: str, val: Any) -> bool:
             and field_descriptor.message_type == val.get_proto_class().DESCRIPTOR
         ) or (
             isinstance(val, Enum)
-            and field_descriptor.enum_type == val.get_proto_class().DESCRIPTOR
+            and field_descriptor.enum_type == val.get_proto_class().DESCRIPTOR  # type: ignore
         ):
             return True
@@ -628,7 +661,7 @@ def _is_valid_type_for_field(cls, field_name: str, val: Any) -> bool:
             return isinstance(val, bytes)
 
         # If it's a primitive, use protobuf type checkers
-        checker = proto_type_checkers.GetTypeChecker(field_descriptor)
+        checker = proto_type_checkers.GetTypeChecker(field_descriptor)  # type: ignore
         try:
             checker.CheckValue(val)
             return True
@@ -730,7 +763,7 @@ def from_proto(cls, proto):
                         "caikit.interfaces.common.data_model"
                     )
                 ):
-                    kwargs[oneof] = getattr(contained_obj, "values")
+                    kwargs[oneof] = contained_obj.values  # type: ignore
                 else:
                     kwargs[oneof] = contained_obj
             else:
@@ -900,9 +933,9 @@ def fill_proto(self, proto):
                 seq_dm = subproto.__class__
                 try:
                     subproto.CopyFrom(seq_dm(values=attr))
-                    log.debug4("Successfully fill proto for %s", field)
+                    log.debug4("Successfully fill proto for %s", field)  # type: ignore
                 except TypeError:
-                    log.debug4("not the correct union list type")
+                    log.debug4("not the correct union list type")  # type: ignore
             else:
                 attr.fill_proto(subproto)
@@ -936,7 +969,7 @@ def to_dict(self) -> dict:
         fields_to_dict = []
         for field in self.fields:
             if (
-                not field in self._fields_to_oneof
+                field not in self._fields_to_oneof
                 or self.which_oneof(self._fields_to_oneof[field]) == field
             ):
                 fields_to_dict.append(field)
@@ -970,7 +1003,7 @@ def to_json(self, **kwargs) -> str:
         """Convert to a json representation."""
 
         def _default_serialization_overrides(obj):
-            """Default handler for nonserializable objects; currently this only handles
+            """Default handler for non-serializable objects; currently this only handles
             - bytes
             - datetime.datetime
             """
@@ -986,7 +1019,9 @@ def _default_serialization_overrides(obj):
 
         return json.dumps(self.to_dict(), **kwargs)
 
-    def to_file(self, file_obj: IOBase) -> Optional["File"]:
+    def to_file(
+        self, file_obj: IOBase
+    ) -> Optional["File"]:  # type: ignore # noqa: F821 # see TYPE_CHECKING note at the top
        """Export a DataBaseObject into a file-like object `file_obj`.
 
        If the DataBase object has requirements around file name or file type
        it can return them via the optional "File" return object
diff --git a/caikit/core/data_model/dataobject.py b/caikit/core/data_model/dataobject.py
index 8f7937250..8aff57972 100644
--- a/caikit/core/data_model/dataobject.py
+++ b/caikit/core/data_model/dataobject.py
@@ -96,7 +96,7 @@ class will be the raw input representation of a @dataclass
         # Get the annotations that will go into the dataclass
         if name != "DataObjectBase":
             field_names = attrs.get("__annotations__")
-            parent_dataobjects = [
+            parent_dataobjects: List[_DataBaseMetaClass] = [
                 base for base in bases if isinstance(base, _DataBaseMetaClass)
             ]
             field_name_sets = [base.fields for base in parent_dataobjects]
@@ -129,7 +129,7 @@ class DataObjectBase(DataBase, metaclass=_DataObjectBaseMetaClass):
     """
 
 
-_DataObjectBaseT = TypeVar("_DataObjectBaseT", bound=DataObjectBase)
+_DataObjectBaseT = TypeVar("_DataObjectBaseT", bound=Type[DataObjectBase])
 
 
 def dataobject(*args, **kwargs) -> Callable[[_DataObjectBaseT], _DataObjectBaseT]:
@@ -186,19 +186,21 @@ def decorator(cls: _DataObjectBaseT) -> _DataObjectBaseT:
         # If it's not an enum, fill in any missing field defaults as None
         # and make sure it's a dataclass
         if not issubclass(cls, Enum):
-            log.debug2("Wrapping data class %s", cls)
+            # alog needs a stub file or some method of typing the monkey-patched methods.
+            # Meanwhile, disable the type-checker for those calls.
+            log.debug2("Wrapping data class %s", cls)  # type: ignore
             user_defined_defaults = {}
             for annotation in getattr(cls, "__annotations__", {}):
                 user_defined_default = getattr(cls, annotation, dataclasses.MISSING)
                 if user_defined_default == dataclasses.MISSING:
-                    log.debug3("Filling in None default for %s.%s", cls, annotation)
+                    log.debug3("Filling in None default for %s.%s", cls, annotation)  # type: ignore
                     setattr(cls, annotation, None)
                 else:
                     user_defined_defaults[annotation] = user_defined_default
             # If the current __init__ is auto-generated by dataclass, remove
             # it so that a new one is created with the new defaults
             if _has_dataclass_init(cls):
-                log.debug3("Resetting default dataclass init")
+                log.debug3("Resetting default dataclass init")  # type: ignore
                 delattr(cls, "__init__")
             cls = dataclasses.dataclass(repr=False)(cls)
             setattr(cls, _USER_DEFINED_DEFAULTS, user_defined_defaults)
@@ -210,25 +212,27 @@ def decorator(cls: _DataObjectBaseT) -> _DataObjectBaseT:
             _AUTO_GEN_PROTO_CLASSES.append(proto_class)
 
         # Add enums to the global enums module
-        for enum_class in _get_all_enums(proto_class):
-            log.debug2("Importing enum [%s]", enum_class.DESCRIPTOR.name)
+        # (The type-checking gets too gnarly with google._upb._message.MessageMeta vs.
+        # the _message.Message class, so disable the type-checker for these right now.)
+        for enum_class in _get_all_enums(proto_class):  # type: ignore
+            log.debug2("Importing enum [%s]", enum_class.DESCRIPTOR.name)  # type: ignore
             enums.import_enum(enum_class)
 
         # Declare the merged class that binds DataBase to the wrapped class with
         # this generated proto class
         if not isinstance(proto_class, EnumTypeWrapper):
-            setattr(cls, "_proto_class", proto_class)
+            cls._proto_class = proto_class  # type: ignore
             cls = _make_data_model_class(proto_class, cls)

             # If this was a default-generated dataclass __init__ and there are
             # any oneofs, we need to augment the __init__ to support kwargs for
             # the individual fields
-            if _has_dataclass_init(cls) and cls._fields_oneofs_map:
-                setattr(cls, "__init__", _make_oneof_init(cls))
+            if _has_dataclass_init(cls) and cls._fields_oneofs_map:  # type: ignore
+                cls.__init__ = _make_oneof_init(cls)
         else:
-            enums.import_enum(proto_class, cls)
-            setattr(cls, "_proto_enum", proto_class)
+            enums.import_enum(proto_class, cls)  # type: ignore
+            cls._proto_enum = proto_class  # type: ignore
 
         # Return the decorated class
         return cls
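
For context, a minimal sketch of the `@dataobject` decorator these hunks modify. The package name and fields are illustrative, and the import path is assumed from caikit's public data-model API:

```python
from caikit.core.data_model import DataObjectBase, dataobject

@dataobject(package="demo")
class Greeting(DataObjectBase):
    text: str
    lang: str

g = Greeting(text="hello", lang="en")  # generated __init__
print(g.to_json())                     # serialized via the DataBase machinery
print(Greeting().lang)                 # None: missing defaults are back-filled
```
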
diff --git a/caikit/core/data_model/enums.py b/caikit/core/data_model/enums.py
index 34a9947ab..312ff9075 100644
--- a/caikit/core/data_model/enums.py
+++ b/caikit/core/data_model/enums.py
@@ -39,14 +39,9 @@
 def to_dict(cls) -> Dict[str, int]:
     """Return a dict representation of the keys and values"""
     if not hasattr(cls, "__dict_repr__"):
-        setattr(
-            cls,
-            "__dict_repr__",
-            {
-                entry.name: entry.value
-                for entry in cls  # pylint: disable=not-an-iterable
-            },
-        )
+        cls.__dict_repr__ = {
+            entry.name: entry.value for entry in cls  # pylint: disable=not-an-iterable
+        }
     return cls.__dict_repr__
 
 
@@ -54,7 +49,7 @@ def to_dict(cls) -> Dict[str, int]:
 def to_munch(cls) -> munch.Munch:
     """Return a munchified version of the enum"""
     if not hasattr(cls, "__munch_repr__"):
-        setattr(cls, "__munch_repr__", munch.Munch(cls.to_dict()))
+        cls.__munch_repr__ = munch.Munch(cls.to_dict())
     return cls.__munch_repr__
 
 
@@ -90,8 +85,8 @@ def import_enum(
     enum_class = Enum._create_(name, proto_enum.items())
 
     # Add extra utility functions
-    setattr(enum_class, "to_dict", to_dict)
-    setattr(enum_class, "to_munch", to_munch)
+    enum_class.to_dict = to_dict
+    enum_class.to_munch = to_munch
 
     globals()[name] = enum_class
     rev_name = name + "Rev"
diff --git a/caikit/core/data_model/json_dict.py b/caikit/core/data_model/json_dict.py
index 20d224ea2..1b1a7742e 100644
--- a/caikit/core/data_model/json_dict.py
+++ b/caikit/core/data_model/json_dict.py
@@ -110,9 +110,7 @@ def _value_to_struct_value(value, struct_class, value_class, list_value_class):
         )
     elif isinstance(value, bool):
         struct_value = value_class(bool_value=value)
-    elif isinstance(value, int):
-        struct_value = value_class(number_value=value)
-    elif isinstance(value, float):
+    elif isinstance(value, (int, float)):
         struct_value = value_class(number_value=value)
     elif isinstance(value, str):
         struct_value = value_class(string_value=value)
diff --git a/caikit/core/data_model/protobufs/__init__.py b/caikit/core/data_model/protobufs/__init__.py
index 930a305a8..6e644a8af 100644
--- a/caikit/core/data_model/protobufs/__init__.py
+++ b/caikit/core/data_model/protobufs/__init__.py
@@ -80,11 +80,11 @@ def import_protobufs(proto_dir, package_base_name, current_globals):
     all_enum_names = []
     for module in all_modules:
         if module.DESCRIPTOR.package.startswith(_package_name):
-            for message_name in module.DESCRIPTOR.message_types_by_name.keys():
+            for message_name in module.DESCRIPTOR.message_types_by_name:
                 message_val = getattr(module, message_name)
                 current_globals[message_name] = message_val
                 globals()[message_name] = message_val
-            for enum_name in module.DESCRIPTOR.enum_types_by_name.keys():
+            for enum_name in module.DESCRIPTOR.enum_types_by_name:
                 enum_val = getattr(module, enum_name)
                 current_globals[enum_name] = enum_val
                 globals()[enum_name] = enum_val
diff --git a/caikit/core/data_model/streams/csv_column_formatter.py b/caikit/core/data_model/streams/csv_column_formatter.py
index 0d9ba73dd..55ff3fe09 100644
--- a/caikit/core/data_model/streams/csv_column_formatter.py
+++ b/caikit/core/data_model/streams/csv_column_formatter.py
@@ -19,6 +19,7 @@
 """
 # Standard
 from typing import Dict
+import copy
 
 # First Party
 import alog
@@ -92,26 +93,22 @@ def _convert(data_item):
             # Don't mutate a list that we're iterating on here
             # Subsequent re-entries into the stream would mutate the list further
             # and really mess things up
-            data_item_copy = list(data_item[:])
+            data_item_copy = copy.deepcopy(data_item)
 
-            last_type = None
             for i, (element, type_) in enumerate(
                 zip(data_item, self._expected_columns.values())
             ):
-                last_type = type_
-                if type_ == list:
+                if type_ is list:
                     data_item_copy[i] = CSVColumnFormatter._attempt_to_listify(
                         element
                     )
 
-            if len(data_item) > len(self._expected_columns):
-                # More data in the data item left...
-                if last_type == list:
-                    # Last element was a list, so slurp the rest of the row in
-                    length = len(self._expected_columns)
+            if len(data_item) > len(self._expected_columns) and (type_ is list):
+                # Last element was a list, so slurp the rest of the row in
+                length = len(self._expected_columns)
 
-                    data_item_copy[length - 1].extend(data_item[length:])
-                    data_item_copy = data_item_copy[0:length]
+                data_item_copy[length - 1].extend(data_item[length:])
+                data_item_copy = data_item_copy[0:length]
 
             return data_item_copy
diff --git a/caikit/core/data_model/streams/data_stream.py b/caikit/core/data_model/streams/data_stream.py
index deee46cbc..36db5f6fe 100644
--- a/caikit/core/data_model/streams/data_stream.py
+++ b/caikit/core/data_model/streams/data_stream.py
@@ -17,6 +17,7 @@
 """
 # Standard
 from collections.abc import Iterable
+from functools import cached_property
 from glob import glob
 from io import UnsupportedOperation
 from typing import Dict, Generic, List, Tuple, TypeVar, Union
@@ -47,6 +48,7 @@
 
 T = TypeVar("T")
 
+
 # ghart: These public methods are all needed. This class is essentially its own factory, so these
 # are all the different ways of coercing different data sources into a common stream class
 # pylint: disable=too-many-public-methods
@@ -103,7 +105,6 @@ def __init__(self, generator_func, *args, **kwargs):
 
         self.generator_func = generator_func
         self.generator_args, self.generator_kwargs = args, kwargs
-        self._length = None
 
     @classmethod
     def from_iterable(cls, data: typing.Iterable[T]) -> "DataStream[T]":
@@ -314,15 +315,14 @@ def from_csv(cls, filename: str, *args, skip=0, **kwargs) -> "DataStream[List]":
     @classmethod
     def _from_csv_generator(cls, filename, skip, *csv_args, **csv_kwargs):
         # open the csv file (closure around `filename`)
-        with open(filename, mode="r", encoding="utf8") as fh:
+        with open(filename, encoding="utf8") as fh:
             # skip lines if requested
             for _ in range(skip):
                 # pylint: disable=stop-iteration-return
                 next(fh)
 
             # for each line of the csv file, yield a list
-            for line in csv.reader(fh, *csv_args, **csv_kwargs):
-                yield line
+            yield from csv.reader(fh, *csv_args, **csv_kwargs)
 
     @classmethod
     def from_header_csv(cls, filename: str, *args, **kwargs) -> "DataStream[Dict]":
@@ -368,7 +368,7 @@ def from_header_csv(cls, filename: str, *args, **kwargs) -> "DataStream[Dict]":
     @classmethod
     def _from_header_csv_generator(cls, filename, *csv_args, **csv_kwargs):
         # open the csv file (closure around `filename`)
-        with open(filename, mode="r", encoding="utf8") as fh:
+        with open(filename, encoding="utf8") as fh:
             yield from cls._from_header_csv_buffer_generator(
                 fh, *csv_args, **csv_kwargs
             )
@@ -384,8 +384,7 @@ def _from_header_csv_buffer_generator(cls, fh: typing.IO, *csv_args, **csv_kwarg
             RuntimeError("File handler for csv with header not seekable"),
         )
         # for each line of the csv file, yield a dict
-        for line in csv.DictReader(fh, *csv_args, **csv_kwargs):
-            yield line
+        yield from csv.DictReader(fh, *csv_args, **csv_kwargs)
 
     @classmethod
     def from_txt(cls, filename: str) -> "DataStream[str]":
@@ -425,7 +424,7 @@ def from_txt(cls, filename: str) -> "DataStream[str]":
     @classmethod
     def _from_txt_generator(cls, filename):
         # open the file (closure around `filename`)
-        with open(filename, mode="r", encoding="utf8") as fh:
+        with open(filename, encoding="utf8") as fh:
             # for each line of the file
             for line in fh:
                 # strip new lines and carriage returns and yield the line
@@ -573,8 +572,7 @@ def _from_csv_collection_generator(cls, dirname):
         for filename in glob(os.path.join(dirname, "*.csv")):
             data_stream_list.append(cls.from_header_csv(filename=filename))
         # yield the combined data item once flattened
-        for data_item in DataStream.chain(data_stream_list).flatten():
-            yield data_item
+        yield from DataStream.chain(data_stream_list).flatten()
 
     @classmethod
     def from_jsonl_collection(cls, dirname: str) -> "DataStream[Dict]":
@@ -602,8 +600,7 @@ def _from_jsonl_collection_generator(cls, dirname):
         for filename in glob(os.path.join(dirname, "*.jsonl")):
             data_stream_list.append(cls.from_jsonl(filename=filename))
         # yield the combined data item once flattened
-        for data_item in DataStream.chain(data_stream_list).flatten():
-            yield data_item
+        yield from DataStream.chain(data_stream_list).flatten()
 
     @classmethod
     def from_multipart_file(cls, filename: str) -> "DataStream[JsonDictValue]":
@@ -806,8 +803,7 @@ def flatten(self) -> "DataStream":
 
         def generator_func():
             for inner_stream in self:
-                for data_item in inner_stream:
-                    yield data_item
+                yield from inner_stream
 
         return DataStream(generator_func)
 
@@ -995,14 +991,20 @@ def __iter__(self):
         return generator
 
     def __len__(self):
+        """See property method self._length"""
+        return self._length
+
+    @cached_property
+    def _length(self):
         """Return the number of data items contained in this data stream. This requires that the
-        data stream be iterated over, which may be time consuming. This value is then stored
+        data stream be iterated over, which may be time-consuming. This value is then stored
         internally so that subsequent calls do not iterate over the data stream again.
-        """
-        if self._length is None:
-            self._length = sum(1 for data_item in self)
-        return self._length
 
+        This is implemented as a cached_property so that subclasses of DataStream which implement
+        their own __getstate__ and __setstate__ do not have to account for the existence of
+        self._length
+        """
+        return sum(1 for _ in self)
 
     def __or__(self, module):
         """Feed this data stream into the `.stream` method of a module. This is syntactic sugar
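
The `__len__`/`_length` change above moves the length cache from a hand-rolled `_length = None` attribute to `functools.cached_property`. An illustrative sketch (not part of the patch), with the import path assumed from caikit's public API:

```python
from caikit.core.data_model import DataStream

stream = DataStream.from_iterable([1, 2, 3])
assert len(stream) == 3  # first call walks the stream once and caches the count
assert len(stream) == 3  # served from the cached_property, no re-iteration
```
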
diff --git a/caikit/core/data_model/streams/multipart_decoder.py b/caikit/core/data_model/streams/multipart_decoder.py
index 91ea03fef..79234c8ac 100644
--- a/caikit/core/data_model/streams/multipart_decoder.py
+++ b/caikit/core/data_model/streams/multipart_decoder.py
@@ -148,7 +148,7 @@ def _get_multipart_boundary(file) -> str:
 
 def _get_first_nonempty_line(file) -> str:
     """Return the first line of the file with content. Returns empty string if none exists."""
-    with open(file, "r", encoding="utf-8") as fp:
+    with open(file, encoding="utf-8") as fp:
         for line in fp:
             stripped_line = line.strip()
             if stripped_line != "":
diff --git a/caikit/core/data_model/streams/validator.py b/caikit/core/data_model/streams/validator.py
index cfcaeef8c..fde5edfe7 100644
--- a/caikit/core/data_model/streams/validator.py
+++ b/caikit/core/data_model/streams/validator.py
@@ -114,9 +114,10 @@ def _validate_data(self, data_item: object, data_item_number: int) -> None:
                 self._expected_keys.values(), data_item, range(len(data_item))
             ):
                 if not isinstance(element, type_):
-                    # pylint: disable=too-many-format-args
-                    message = "Expected element {} in data item to be of type {}, "
-                    "but got {}".format(index, type_, type(element))
+                    message = (
+                        f"Expected element {index} in data item to be "
+                        f"of type {type_}, but got {type(element)}"
+                    )
                     raise DataValidationError(message, data_item_number)
         else:
diff --git a/caikit/core/exceptions/caikit_core_exception.py b/caikit/core/exceptions/caikit_core_exception.py
index 1af8d7054..c9ff72158 100644
--- a/caikit/core/exceptions/caikit_core_exception.py
+++ b/caikit/core/exceptions/caikit_core_exception.py
@@ -17,6 +17,7 @@
 
 # Standard
 from enum import Enum
+import uuid
 
 
 class CaikitCoreStatusCode(Enum):
@@ -36,3 +37,4 @@ class CaikitCoreException(Exception):
     def __init__(self, status_code: CaikitCoreStatusCode, message: str) -> None:
         self.status_code = status_code
         self.message = message
+        self.id = uuid.uuid4().hex
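
With the new `id` field, every `CaikitCoreException` carries a fresh identifier usable for log correlation. Hypothetical usage; the `NOT_FOUND` member name is an assumption about the enum's contents:

```python
from caikit.core.exceptions.caikit_core_exception import (
    CaikitCoreException,
    CaikitCoreStatusCode,
)

try:
    raise CaikitCoreException(CaikitCoreStatusCode.NOT_FOUND, "no such model")
except CaikitCoreException as err:
    # err.id is a uuid4 hex string assigned at construction time
    print(f"[{err.id}] {err.status_code.name}: {err.message}")
```
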
diff --git a/caikit/core/model_management/factories.py b/caikit/core/model_management/factories.py
index c39cc5dba..24fd26a3e 100644
--- a/caikit/core/model_management/factories.py
+++ b/caikit/core/model_management/factories.py
@@ -21,6 +21,7 @@
 from .local_model_initializer import LocalModelInitializer
 from .local_model_trainer import LocalModelTrainer
 from .multi_model_finder import MultiModelFinder
+from .multi_model_initializer import MultiModelInitializer
 
 # Model trainer factory. A trainer is responsible for performing the train
 # operation against a configured framework connection.
@@ -38,3 +39,4 @@
 # location.
 model_initializer_factory = ImportableFactory("ModelInitializer")
 model_initializer_factory.register(LocalModelInitializer)
+model_initializer_factory.register(MultiModelInitializer)
diff --git a/caikit/core/model_management/local_model_trainer.py b/caikit/core/model_management/local_model_trainer.py
index e1348c6b7..b6758766a 100644
--- a/caikit/core/model_management/local_model_trainer.py
+++ b/caikit/core/model_management/local_model_trainer.py
@@ -33,14 +33,14 @@
 from ..data_model import TrainingStatus
 from ..exceptions import error_handler
 from ..modules import ModuleBase
-from ..toolkit.destroyable_process import DestroyableProcess
-from ..toolkit.destroyable_thread import DestroyableThread
 from ..toolkit.logging import configure as configure_logging
 from .model_trainer_base import ModelTrainerBase, TrainingInfo
 from caikit.core.exceptions.caikit_core_exception import (
     CaikitCoreException,
     CaikitCoreStatusCode,
 )
+from caikit.core.toolkit.concurrency.destroyable_process import DestroyableProcess
+from caikit.core.toolkit.concurrency.destroyable_thread import DestroyableThread
 import caikit
 
 log = alog.use_channel("LOC-TRNR")
diff --git a/caikit/core/model_management/multi_model_initializer.py b/caikit/core/model_management/multi_model_initializer.py
new file mode 100644
index 000000000..f5b5a0bb8
--- /dev/null
+++ b/caikit/core/model_management/multi_model_initializer.py
@@ -0,0 +1,146 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The MultiModelInitializer configures a set of other model initializers that will be used
+in sequence to try loading models.
+
+Configuration for MultiModelInitializer lives under the config as follows:
+
+model_management:
+    initializers:
+        <initializer name>:
+            type: MULTI
+            config:
+                # Sequence of other initializer names to use in priority order
+                initializer_priority:
+                    - other_initializer1
+                    - other_initializer2
+"""
+# Standard
+from typing import Optional
+
+# First Party
+import aconfig
+import alog
+
+# Local
+from ...config import get_config
+from ..exceptions import error_handler
+from ..modules import ModuleBase, ModuleConfig
+from .model_initializer_base import ModelInitializerBase
+
+# NOTE: Top-level import done so that global MODEL_MANAGER can be used at
+# construction time without incurring a circular dependency
+import caikit.core
+
+log = alog.use_channel("MINIT")
+error = error_handler.get(log)
+
+
+class MultiModelInitializer(ModelInitializerBase):
+    __doc__ = __doc__
+
+    name = "MULTI"
+
+    def __init__(self, config: aconfig.Config, instance_name: str):
+        """Initialize with the sequence of initializers to use"""
+        self._instance_name = instance_name
+        initializer_priority = config.initializer_priority
+        error.type_check(
+            "",
+            list,
+            initializer_priority=initializer_priority,
+        )
+        error.type_check_all(
+            "",
+            str,
+            initializer_priority=initializer_priority,
+        )
+        error.value_check(
+            "",
+            initializer_priority,
+            "Must provide at least one valid initializer",
+        )
+        config_initializers = get_config().model_management.initializers
+        invalid_initializers = [
+            initializer
+            for initializer in initializer_priority
+            if initializer not in config_initializers
+        ]
+        error.value_check(
+            "",
+            not invalid_initializers,
+            "Invalid initializers given in initializer_priority: {}",
+            invalid_initializers,
+        )
+        error.value_check(
+            "",
+            self._instance_name not in config_initializers,
+            "Cannot include self in multi initializer priority",
+        )
+        model_manager = config.model_manager or caikit.core.MODEL_MANAGER
+        log.debug2(
+            "Setting up %s with initializer priority: %s",
+            self.name,
+            initializer_priority,
+        )
+        self._initializers = [
+            model_manager.get_initializer(initializer)
+            for initializer in initializer_priority
+        ]
+
+    def init(
+        self,
+        model_config: ModuleConfig,
+        **kwargs,
+    ) -> Optional[ModuleBase]:
+        """Iterate through the sequence of initializers and return the first one that
+        succeeds
+        """
+        for idx, initializer in enumerate(self._initializers):
+            log.debug2(
+                "Trying to init %s with initializer %d of type %s",
+                model_config.module_id,
+                idx,
+                initializer.name,
+            )
+            try:
+                module = initializer.init(model_config, **kwargs)
+                if module:
+                    log.debug(
+                        "Init model %s with initializer %d of type %s",
+                        model_config.module_id,
+                        idx,
+                        initializer.name,
+                    )
+                    return module
+                log.debug2(
+                    "Initializer %d of type %s unable to init %s",
+                    idx,
+                    initializer.name,
+                    model_config.module_id,
+                )
+            except Exception as err:  # pylint: disable=broad-exception-caught
+                log.debug2(
+                    "Initializer %d of type %s failed to load %s: %s",
+                    idx,
+                    initializer.name,
+                    model_config.module_id,
+                    err,
+                )
+                log.debug4("Initializer error", exc_info=True)
+
+        # No initializer succeeded
+        log.warning("Unable to init %s with any initializer", model_config.module_id)
+        return None
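
A hypothetical configuration exercising the new MULTI initializer (not part of the patch). `caikit.configure(config_dict=...)` is assumed from caikit's config API, and the initializer names are illustrative:

```python
import caikit

caikit.configure(
    config_dict={
        "model_management": {
            "initializers": {
                # MULTI tries each named initializer in order until one
                # returns a loaded module
                "default": {
                    "type": "MULTI",
                    "config": {"initializer_priority": ["local_a", "local_b"]},
                },
                # Both targets are stock LOCAL initializers here; a real
                # deployment would point them at different backends
                "local_a": {"type": "LOCAL"},
                "local_b": {"type": "LOCAL"},
            }
        }
    }
)
```
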
model saved to disk) load_path = get_config().load_path - if load_path is not None and isinstance(module_path, str): - if not os.path.exists(module_path): - full_module_path = os.path.join(load_path, module_path) - if os.path.exists(full_module_path): - module_path = full_module_path + if ( + load_path is not None + and isinstance(module_path, str) + and not os.path.exists(module_path) + ): + full_module_path = os.path.join(load_path, module_path) + if os.path.exists(full_module_path): + module_path = full_module_path # Ensure that we have a loadable directory. error.type_check("", str, BytesIO, bytes, module_path=module_path) diff --git a/caikit/core/modules/base.py b/caikit/core/modules/base.py index d96eb9e51..ae9f87974 100644 --- a/caikit/core/modules/base.py +++ b/caikit/core/modules/base.py @@ -48,6 +48,7 @@ log = alog.use_channel("MODULE") error = error_handler.get(log) + # pylint: disable=too-many-public-methods class ModuleBase(metaclass=_ModuleBaseMeta): """Abstract base class from which all modules should inherit.""" @@ -78,6 +79,33 @@ def metadata(self) -> Dict[str, Any]: self._metadata = {} return self._metadata + @property + def module_metadata(cls) -> Dict[str, Any]: + """Helper property to return metadata about a Module. This function + is separate from `metadata` as this is specific for the class module. This + function also requires a flat metadata structure without nested dictionaries. + + NOTE: This should be a @classmethod but using @property/@classmethod together has + been deprecated + + Returns: + Dict[str, str]: A dictionary of this ModuleBases's metadata + """ + + return {"name": cls.MODULE_NAME, "version": cls.MODULE_VERSION} + + @property + def public_model_info(cls) -> Dict[str, Any]: + """Helper property to return public metadata about a specific Model. This + function is separate from `metdata` as that contains the entire ModelConfig + which might not want to be shared/exposed. 
+ + Returns: + Dict[str, str]: A dictionary of this models's public metadata + """ + + return {} + def set_load_backend(self, load_backend): """Method used by the model manager to indicate the load backend that was used to load this module @@ -89,8 +117,8 @@ def get_inference_signature( cls, input_streaming: bool, output_streaming: bool, - task: Type["caikit.core.TaskBase"] = None, - ) -> Optional["caikit.core.signature_parsing.CaikitMethodSignature"]: + task: Type["core.TaskBase"] = None, + ) -> Optional["core.signature_parsing.CaikitMethodSignature"]: """Returns the inference method signature that is capable of running the module's task for the given flavors of input and output streaming """ @@ -109,8 +137,8 @@ def get_inference_signature( @classmethod def get_inference_signatures( - cls, task: Type["caikit.core.TaskBase"] - ) -> List[Tuple[bool, bool, "caikit.core.signature_parsing.CaikitMethodSignature"]]: + cls, task: Type["core.TaskBase"] + ) -> List[Tuple[bool, bool, "core.signature_parsing.CaikitMethodSignature"]]: """Returns inference method signatures for all supported flavors of input and output streaming for a given task """ @@ -369,8 +397,7 @@ def timed_run(self, *args, num_seconds=None, num_iterations=None, **kwargs): time_passed = time.time() - start_time # stop on seconds or iterations depending on input arguments - # pylint: disable=unnecessary-lambda-assignment - continue_condition = ( + continue_condition = ( # noqa: E731 # lambda-assignment lambda t_p, i_p: t_p <= num_seconds if num_seconds else i_p < num_iterations ) response = None @@ -741,10 +768,7 @@ def _load_evaluation_dataset(dataset_path): return fileio.load_json(dataset_path) # if all else fails - error( - "", - ValueError("Unsure of how to load: {0}".format(dataset_path)), - ) + error("", ValueError(f"Unsure of how to load: {dataset_path}")) @staticmethod def _extract_gold_annotations(gold_set): diff --git a/caikit/core/modules/decorator.py b/caikit/core/modules/decorator.py index 58e864a44..ca2d50077 100644 --- a/caikit/core/modules/decorator.py +++ b/caikit/core/modules/decorator.py @@ -54,7 +54,7 @@ def module( base_module: Union[str, Type[ModuleBase]] = None, backend_config_override: Optional[Dict] = None, ): - f"""Apply this decorator to any class that should be treated as a caikit module + """Apply this decorator to any class that should be treated as a caikit module (i.e., extends`{caikit.core.ModuleBase}) and registered with caikit.core so that the library "knows" the class is a caikit module and is capable of loading instances of the module. 
@@ -227,7 +227,7 @@ def decorator(cls_):

         tasks_in_hierarchy = []
         for class_ in cls_.mro():
-            if hasattr(class_, "_TASK_CLASSES"):
+            if hasattr(class_, "_TASK_CLASSES") and class_ is not cls_:
                 tasks_in_hierarchy.extend(class_._TASK_CLASSES)

         if tasks_in_hierarchy:
@@ -240,7 +240,7 @@ def decorator(cls_):
             )

         # Set its own backend_type as an attribute
-        setattr(cls_, "BACKEND_TYPE", backend_type)
+        cls_.BACKEND_TYPE = backend_type

         # Verify UUID and add this module to the module registry
         if not backend_module_impl:
diff --git a/caikit/core/modules/meta.py b/caikit/core/modules/meta.py
index a55de4dd6..424ed721f 100644
--- a/caikit/core/modules/meta.py
+++ b/caikit/core/modules/meta.py
@@ -66,7 +66,7 @@ def injected_load(*args):
 """
 # Standard
-from typing import Set
+from typing import TYPE_CHECKING, Set
 import abc
 import functools
@@ -77,6 +77,10 @@ def injected_load(*args):
 from ..exceptions import error_handler
 from .config import ModuleConfig

+if TYPE_CHECKING:
+    # Local
+    from caikit.core import TaskBase
+
 log = alog.use_channel("METADATA_INJECT")
 error = error_handler.get(log)
@@ -154,7 +158,7 @@ def metadata_injecting_load(clz, *args, **kwargs):
         return super().__new__(mcs, name, bases, attrs)

     @property
-    def tasks(cls) -> Set["caikit.core.TaskBase"]:
+    def tasks(cls) -> Set["TaskBase"]:
         return set(cls._TASK_CLASSES)

     def __setattr__(cls, name, val):
diff --git a/caikit/core/modules/saver.py b/caikit/core/modules/saver.py
index 7007463bc..a6d70abe6 100644
--- a/caikit/core/modules/saver.py
+++ b/caikit/core/modules/saver.py
@@ -41,7 +41,8 @@
 class ModuleSaver:
     """A module saver that provides common functionality used for saving modules and also a context
-    manager that cleans up gracefully in case an error is encountered during the save process.
+    manager that cleans up if an error is encountered during the save process and the
+    model_path did not already exist.
     """

     SAVED_KEY_NAME = "saved"
@@ -51,7 +52,7 @@ class ModuleSaver:
     MODULE_ID_KEY_NAME = "module_id"
     MODULE_CLASS_KEY_NAME = "module_class"

-    def __init__(self, module: ModuleBase, model_path):
+    def __init__(self, module: ModuleBase, model_path, exist_ok=True):
         """Construct a new module saver.

         Args:
@@ -60,8 +61,10 @@ def __init__(self, module: ModuleBase, model_path):
             model_path (str): The absolute path to the directory where the model will be saved.
                 If this directory does not exist, it will be created.
+            exist_ok (bool): Whether to allow overwriting existing model_path files.
         """
         self.model_path = os.path.normpath(model_path)
+        self.exist_ok = exist_ok

         # Get possibly nested caikit library path
         module_path = module.__module__
@@ -313,20 +316,31 @@ def save_module_list(self, modules, config_key, **kwargs):

     def __enter__(self):
         """Enter the module saver context. This creates the `model_path` directory. If this
         context successfully exits, then the model configuration and all files it contains will
-        be written and saved to disk inside the `model_path` directory. If any uncaught exceptions
-        are thrown inside this context, then `model_path` will be removed.
+        be written and saved to disk inside the `model_path` directory.
+
+        If `exist_ok` is False, an exception will be raised before touching existing `model_path`
+        files.
+
+        If any uncaught exceptions are thrown inside this context, and `exist_ok` is False,
+        then this new `model_path` will be removed. If `exist_ok` is True, the files will be kept
+        and may include incomplete updates.
         """
-        os.makedirs(self.model_path, exist_ok=True)
+        os.makedirs(self.model_path, exist_ok=self.exist_ok)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         """Exit the module saver context. If this context successfully exits, then the model
         configuration and all files it contains will be written and saved to disk inside the
-        `model_path` directory. If any uncaught exceptions are thrown inside this context, then
-        `model_path` will be removed.
+        `model_path` directory.
+
+        If any uncaught exceptions are thrown inside this context, and `exist_ok` is False,
+        then this new `model_path` will be removed. If `exist_ok` is True, the files will be kept
+        and may include incomplete updates.
         """
         if exc_type is not None:
-            shutil.rmtree(self.model_path, ignore_errors=True)
+            if not self.exist_ok:
+                # The path was newly created by this saver, so it is safe to remove
+                shutil.rmtree(self.model_path, ignore_errors=True)
         return ModuleConfig(self.config).save(self.model_path)
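To make the new `exist_ok` semantics concrete, here is a hedged sketch (not part of this patch) of a `save` implementation using the saver as a context manager; the parameter name and `update_config` payload are illustrative, and the import path is assumed from the file location above.

```python
# Hedged sketch of the exist_ok behavior described in the docstrings above.
from caikit.core.modules import ModuleSaver


def save(self, model_path: str):
    # exist_ok=False: os.makedirs raises instead of touching an existing
    # model_path, and a freshly created directory is removed on error.
    with ModuleSaver(self, model_path, exist_ok=False) as saver:
        saver.update_config({"my_param": 42})  # hypothetical config entry
    # exist_ok=True (the default): overwriting is allowed, and a failed save
    # may leave partially updated files behind.
```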
""" - os.makedirs(self.model_path, exist_ok=True) + os.makedirs(self.model_path, exist_ok=self.exist_ok) return self def __exit__(self, exc_type, exc_val, exc_tb): """Exit the module saver context. If this context successfully exits, then the model configuration and all files it contains will be written and saved to disk inside the - `model_path` directory. If any uncaught exceptions are thrown inside this context, then - `model_path` will be removed. + `model_path` directory. + + If any uncaught exceptions are thrown inside this context, and `exist_ok` is False, + then this new `model_path` will be removed. If `exist_ok` is True, the files will be kept + and may include incomplete updates. """ if exc_type is not None: - shutil.rmtree(self.model_path, ignore_errors=True) + if not self.exist_ok: + # Presume it is okay to rmtree + shutil.rmtree(self.model_path, ignore_errors=True) return ModuleConfig(self.config).save(self.model_path) diff --git a/caikit/core/registries.py b/caikit/core/registries.py index 97ff6b270..fc67a6fe0 100644 --- a/caikit/core/registries.py +++ b/caikit/core/registries.py @@ -19,7 +19,7 @@ """ # Standard -from typing import Any, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, Tuple, Type # First Party import alog @@ -27,7 +27,12 @@ # Local from .exceptions import error_handler +if TYPE_CHECKING: + # Local + from caikit.core import BackendBase, ModuleBase + log = alog.use_channel("REGISTRIES") + error = error_handler.get(log) @@ -38,7 +43,7 @@ MODULE_BACKEND_REGISTRY = {} -def module_registry() -> Dict[str, "caikit.core.ModuleBase"]: +def module_registry() -> Dict[str, "ModuleBase"]: """🌶️🌶️🌶️ This returns global state that should only be mutated if you know what you're doing! Returns the dictionary of decorated @modules that have been imported. @@ -53,9 +58,7 @@ def module_registry() -> Dict[str, "caikit.core.ModuleBase"]: return MODULE_REGISTRY -def module_backend_registry() -> Dict[ - str, Dict[str, Tuple["caikit.core.ModuleBase", Dict]] -]: +def module_backend_registry() -> Dict[str, Dict[str, Tuple["ModuleBase", Dict]]]: """🌶️🌶️🌶️ This returns global state that should only be mutated if you know what you're doing! Returns the module backend registry. This adds more nesting to the module registry, @@ -103,10 +106,10 @@ def module_backend_types() -> Dict[str, str]: return MODULE_BACKEND_TYPES -MODULE_BACKEND_CLASSES: Dict[str, Type["caikit.core.BackendBase"]] = {} +MODULE_BACKEND_CLASSES: Dict[str, Type["BackendBase"]] = {} -def module_backend_classes() -> Dict[str, Type["caikit.core.BackendBase"]]: +def module_backend_classes() -> Dict[str, Type["BackendBase"]]: """🌶️🌶️🌶️ This returns global state that should only be mutated if you know what you're doing! 
Returns the mapping of backend type name to concrete backend class diff --git a/caikit/core/task.py b/caikit/core/task.py index c28721e7a..4810d6394 100644 --- a/caikit/core/task.py +++ b/caikit/core/task.py @@ -239,14 +239,13 @@ def _raise_on_wrong_output_type(cls, output_type, module, output_streaming: bool return # Do some streaming checks - if output_streaming: - if cls._is_iterable_type(output_type): - # task_output_type is already guaranteed to be Iterable[T] - streaming_type = typing.get_args(task_output_type)[0] + if output_streaming and cls._is_iterable_type(output_type): + # task_output_type is already guaranteed to be Iterable[T] + streaming_type = typing.get_args(task_output_type)[0] - for iterable_type in typing.get_args(output_type): - if cls._subclass_check(iterable_type, streaming_type): - return + for iterable_type in typing.get_args(output_type): + if cls._subclass_check(iterable_type, streaming_type): + return raise TypeError( f"Wrong output type for module {module}: " diff --git a/caikit/core/toolkit/concurrency/__init__.py b/caikit/core/toolkit/concurrency/__init__.py new file mode 100644 index 000000000..6caa88338 --- /dev/null +++ b/caikit/core/toolkit/concurrency/__init__.py @@ -0,0 +1,17 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Contains tools for managing concurrent async workloads. 
+""" diff --git a/caikit/core/toolkit/destroyable.py b/caikit/core/toolkit/concurrency/destroyable.py similarity index 100% rename from caikit/core/toolkit/destroyable.py rename to caikit/core/toolkit/concurrency/destroyable.py diff --git a/caikit/core/toolkit/destroyable_process.py b/caikit/core/toolkit/concurrency/destroyable_process.py similarity index 96% rename from caikit/core/toolkit/destroyable_process.py rename to caikit/core/toolkit/concurrency/destroyable_process.py index 1249b22f6..8f4fbbde2 100644 --- a/caikit/core/toolkit/destroyable_process.py +++ b/caikit/core/toolkit/concurrency/destroyable_process.py @@ -31,8 +31,9 @@ import alog # Local -from ..exceptions import error_handler from .destroyable import Destroyable +from .pickling_exception import ExceptionPickler +from caikit.core.exceptions import error_handler log = alog.use_channel("DESTROY-PROC") error = error_handler.get(log) @@ -173,7 +174,9 @@ def run(self): # pragma: no cover err_str, exc_info=True, ) - self._child_conn.send(err) + # Wrap error for safe pickling + pickler = ExceptionPickler(err) + self._child_conn.send(pickler) finally: self._completion_event.set() @@ -187,6 +190,9 @@ def _update_result(self): def error(self) -> Optional[Exception]: self._update_result() + if isinstance(self.__result, ExceptionPickler): + return self.__result.get() + if isinstance(self.__result, Exception): return self.__result diff --git a/caikit/core/toolkit/destroyable_thread.py b/caikit/core/toolkit/concurrency/destroyable_thread.py similarity index 99% rename from caikit/core/toolkit/destroyable_thread.py rename to caikit/core/toolkit/concurrency/destroyable_thread.py index bc751a61b..2cc21ac8c 100644 --- a/caikit/core/toolkit/destroyable_thread.py +++ b/caikit/core/toolkit/concurrency/destroyable_thread.py @@ -60,7 +60,7 @@ def __init__( runnable_func, *runnable_args, work_done_event: Optional[threading.Event] = None, - **runnable_kwargs + **runnable_kwargs, ): threading.Thread.__init__(self) self.work_done_event = work_done_event or threading.Event() @@ -125,7 +125,7 @@ def run(self) -> None: *self.runnable_args, **self.runnable_kwargs ) self.__threw = False - except: # pylint: disable=bare-except + except: # noqa: E722 # bare-except # PEP8 complains, but in this case we really do want to re-throw _any_ exception that # occurred. In the interest of transparently wrapping any work in these threads, we # want to keep exception signatures identical. E.g. if I expect this thread to throw a diff --git a/caikit/core/toolkit/concurrency/pickling_exception.py b/caikit/core/toolkit/concurrency/pickling_exception.py new file mode 100644 index 000000000..93eb17de2 --- /dev/null +++ b/caikit/core/toolkit/concurrency/pickling_exception.py @@ -0,0 +1,168 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +An ExceptionPickler deals with deconstructing any `Exception` type into picklable +parts so that it can be passed across a subprocess boundary without either failing +to un-pickle or losing important context. + +The python BaseException class implements its own __reduce__ method so that all +subclasses support pickling, but it has some intentional drawbacks: +1. It can't know about kwarg arguments to __init__, so it only supports subclasses +that have *arg initializers. However, many custom Exception types contain keyword +arguments in their __init__ methods which cause an unpickling failure. +2. It does not pickle the __cause__ or __context__ of an exception, presumably +because it can't make any guarantees about the picklability of those objects. This +leads to one-line tracebacks on the unpickled exception, because there is no context +to generate a useful stack trace from. +""" +# Standard +import pickle +import re + +# First Party +import alog + +# Local +from caikit.core.exceptions import error_handler + +log = alog.use_channel("EXC_PICKLER") +error = error_handler.get(log) + + +class PickleFailureFallbackException(Exception): + """Exception type used to replace exceptions that just cannot be pickled no matter + how hard we try.""" + + +class ExceptionPickler: + """Instances of this class can safely be pickled with any exception inside""" + + # Matches the specific TypeError that raises when exception classes allow init kwargs + # but do not handle them in __reduce__ + _type_error_expression = re.compile( + r".*__init__\(\) missing \d required positional argument" + ) + # Matches the names of the positional arguments that are missing, from the TypeError's string + _arg_match_expression = re.compile(r".*?'(.+?)'+") + + def __init__(self, exception: BaseException): + """ + Args: + exception: The exception that will be safely pickled within this container + """ + error.type_check( + "", Exception, allow_none=False, exception=exception + ) + self.exception = exception + + def get(self) -> BaseException: + """Returns the exception, reconstructed after pickling as best as possible""" + return self.exception + + def __setstate__(self, state_dict): + """Reconstructs the exception out of the state_dict that is returned by __getstate__""" + if "exception" in state_dict: + self.exception = state_dict["exception"] + else: + initializer = state_dict["initializer"] + self.exception = initializer(*state_dict["args"], **state_dict["kwargs"]) + + if state_dict["cause"]: + self.exception.__cause__ = state_dict["cause"].get() + + if state_dict["context"]: + self.exception.__context__ = state_dict["context"].get() + + def __getstate__(self) -> dict: + """Package up the exception's details into a dict, taking care to: + - include the __cause__ and __context__, which are not serialized by default + - Recursively wrap _those_ in PicklingExceptionWrappers + - Check that this exception _can_ be pickled, and try to handle common problems with + __reduce__ + """ + state_dict = { + "exception": self.exception, + "cause": ( + ExceptionPickler(self.exception.__cause__) + if self.exception.__cause__ + else None + ), + "context": ( + ExceptionPickler(self.exception.__context__) + if self.exception.__context__ + else None + ), + } + + # try/catch pickle errors + try: + pickle.loads(pickle.dumps(self.exception)) + log.debug4("Exception pickled successfully: %s", self.exception) + except TypeError as type_error: + log.debug4("Exception could not be pickled directly: %s", self.exception) + if 
self._type_error_expression.match(str(type_error)):
+                try:
+                    keywords = self._arg_match_expression.findall(str(type_error))
+
+                    log.debug4("Looking for keyword arguments: %s", keywords)
+
+                    # First grab the positional arguments. This should be provided by BaseException
+                    args = self.exception.args
+                    # Then look for each kwarg
+                    kwargs = {}
+                    for kwarg in keywords:
+                        # Try to fetch the attribute
+                        if hasattr(self.exception, kwarg):
+                            arg = getattr(self.exception, kwarg)
+                            kwargs[kwarg] = arg
+                        elif hasattr(self.exception, f"_{kwarg}"):
+                            arg = getattr(self.exception, f"_{kwarg}")
+                            kwargs[kwarg] = arg
+                        else:
+                            raise ValueError(
+                                f"{self.exception} has no attributes matching kwarg name {kwarg}"
+                            )
+
+                    state_dict.pop("exception")
+                    state_dict["initializer"] = type(self.exception)
+                    state_dict["args"] = args
+                    state_dict["kwargs"] = kwargs
+
+                    # check that we can re-build this exception
+                    _ = state_dict["initializer"](
+                        *state_dict["args"], **state_dict["kwargs"]
+                    )
+                    log.debug4(
+                        "Successfully found all the args to re-initialize exception"
+                    )
+
+                except Exception as e:  # pylint: disable=broad-exception-caught
+                    log.debug4(
+                        "Failed to find all args and kwargs to unpickle exception. Reason: %s",
+                        e,
+                    )
+
+                    state_dict["exception"] = PickleFailureFallbackException(
+                        str(self.exception)
+                    )
+            else:
+                log.debug4(
+                    "Could not determine cause of pickling error: %s", type_error
+                )
+
+                state_dict["exception"] = PickleFailureFallbackException(
+                    str(self.exception)
+                )
+
+        return state_dict
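A hedged sketch (not part of this patch) of the failure mode this class handles: an exception whose `__init__` takes a required argument that `BaseException.__reduce__` does not replay fails to unpickle on its own, but survives when wrapped in an `ExceptionPickler`. The `CustomError` class below is hypothetical.

```python
# Hedged sketch: kwargs-carrying exceptions break naive pickling.
import pickle

from caikit.core.toolkit.concurrency.pickling_exception import ExceptionPickler


class CustomError(Exception):
    def __init__(self, message, detail):
        # detail is not passed to super(), so BaseException.args == (message,)
        super().__init__(message)
        self.detail = detail


err = CustomError("boom", "bad input")
try:
    pickle.loads(pickle.dumps(err))
except TypeError:
    # Unpickling replays only CustomError("boom"), so `detail` is missing
    pass

# The wrapper spots the TypeError, recovers `detail` from the attribute of the
# same name, and also preserves __cause__/__context__ for useful tracebacks.
restored = pickle.loads(pickle.dumps(ExceptionPickler(err))).get()
assert restored.detail == "bad input"
```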
diff --git a/caikit/core/toolkit/error_handler.py b/caikit/core/toolkit/error_handler.py
index 64d93d0f4..d2bf9bd8a 100644
--- a/caikit/core/toolkit/error_handler.py
+++ b/caikit/core/toolkit/error_handler.py
@@ -26,7 +26,7 @@
 # Allow DeprecationWarnings through if anything tries to import from `toolkit.errors`
 _warnings.filterwarnings("default", category=DeprecationWarning)
 # And actually warn them
-_warnings.warn(
+_warnings.warn(  # noqa: B028 # no explicit stacklevel keyword argument
     "The caikit.toolkit.error_handler package has moved to caikit.core.exceptions",
     DeprecationWarning,
 )
diff --git a/caikit/core/toolkit/errors/__init__.py b/caikit/core/toolkit/errors/__init__.py
index 58e5565b5..0fd68e9a6 100644
--- a/caikit/core/toolkit/errors/__init__.py
+++ b/caikit/core/toolkit/errors/__init__.py
@@ -26,7 +26,7 @@
 # Allow DeprecationWarnings through if anything tries to import from `toolkit.errors`
 _warnings.filterwarnings("default", category=DeprecationWarning)
 # And actually warn them
-_warnings.warn(
+_warnings.warn(  # noqa: B028 # no explicit stacklevel keyword argument
     "The caikit.toolkit.errors package has moved to caikit.core.exceptions",
     DeprecationWarning,
 )
diff --git a/caikit/core/toolkit/fileio.py b/caikit/core/toolkit/fileio.py
index dcfaabdd2..3c43c83d0 100644
--- a/caikit/core/toolkit/fileio.py
+++ b/caikit/core/toolkit/fileio.py
@@ -29,13 +29,13 @@

 def load_txt(filename):
     """Load a string from a file with utf8 encoding."""
-    with open(filename, mode="r", encoding="utf8") as fh:
+    with open(filename, encoding="utf8") as fh:
         return fh.read()

 def load_txt_lines(filename):
     """Load a list of lines from a text file with utf8 encoding"""
-    with open(filename, mode="r", encoding="utf8") as fh:
+    with open(filename, encoding="utf8") as fh:
         wordlist = list(map(str.strip, fh.readlines()))
         return wordlist
@@ -60,7 +60,7 @@ def save_binary(data, filename):

 def load_csv(filename):
     """Load a csv into a list-of-lists."""
-    with open(filename, mode="r", newline="", encoding="utf-8") as fh:
+    with open(filename, newline="", encoding="utf-8") as fh:
         return list(csv.reader(fh, delimiter=",", quotechar='"'))
@@ -73,7 +73,7 @@ def save_csv(text_list, filename, mode="w"):

 def load_dict_csv(filename):
     """Load a csv into a list-of-dicts."""
-    with open(filename, mode="r", encoding="utf-8") as csv_file:
+    with open(filename, encoding="utf-8") as csv_file:
         csv_reader = csv.DictReader(csv_file)
         return list(csv_reader)
@@ -90,7 +90,7 @@ def save_dict_csv(dict_list, filename, mode="w"):

 def load_json(filename):
     """Load a json file into a dictionary."""
-    with open(filename, mode="r", encoding="utf8") as fh:
+    with open(filename, encoding="utf8") as fh:
         return json.load(fh)
@@ -102,7 +102,7 @@ def save_json(save_dict, filename, mode="w"):

 def load_yaml(filename):
     """Load a yaml file into a dictionary."""
-    with open(filename, mode="r", encoding="utf8") as fh:
+    with open(filename, encoding="utf8") as fh:
         return yaml.safe_load(fh)
diff --git a/caikit/core/toolkit/name_tools.py b/caikit/core/toolkit/name_tools.py
new file mode 100644
index 000000000..999a5eb70
--- /dev/null
+++ b/caikit/core/toolkit/name_tools.py
@@ -0,0 +1,22 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common string functions that are generally helpful for generating runtime RPC names
+and other Protobuf names
+"""
+
+
+def snake_to_upper_camel(string: str) -> str:
+    """Simple snake -> upper camel conversion for descriptors"""
+    return "".join([part[0].upper() + part[1:] for part in string.split("_")])
diff --git a/caikit/core/toolkit/quality_evaluation.py b/caikit/core/toolkit/quality_evaluation.py
index f57006883..14ac3909a 100644
--- a/caikit/core/toolkit/quality_evaluation.py
+++ b/caikit/core/toolkit/quality_evaluation.py
@@ -126,7 +126,7 @@ def run(

         error(
             "",
-            ValueError("Unknown evaluation_type: {0}".format(evaluation_type)),
+            ValueError(f"Unknown evaluation_type: {evaluation_type}"),
         )

     def singlelabel_multiclass_evaluation(self, labels=None) -> dict:
diff --git a/caikit/core/toolkit/wip_decorator.py b/caikit/core/toolkit/wip_decorator.py
index 58022ca04..5bad7ec58 100644
--- a/caikit/core/toolkit/wip_decorator.py
+++ b/caikit/core/toolkit/wip_decorator.py
@@ -129,10 +129,7 @@ def foo(*args, **kwargs):
     foo is still in the BETA phase and subject to change!
     """
-    if args:
-        wrapped_obj = args[0]
-    else:
-        wrapped_obj = None
+    wrapped_obj = args[0] if args else None

     # Set defaults
     category = kwargs.get("category", WipCategory.WIP)
diff --git a/caikit/interfaces/common/data_model/__init__.py b/caikit/interfaces/common/data_model/__init__.py
index 188143db1..e65c6fae4 100644
--- a/caikit/interfaces/common/data_model/__init__.py
+++ b/caikit/interfaces/common/data_model/__init__.py
@@ -25,8 +25,15 @@

 # Local
 # Import individual packages
-from . import primitive_sequences, producer
+from .
import primitive_sequences, producer, vectors from .file import File from .primitive_sequences import BoolSequence, FloatSequence, IntSequence, StrSequence from .producer import ProducerId from .stream_sources import FileReference, ListOfFileReferences +from .vectors import ( + ListOfVector1D, + NpFloat32Sequence, + NpFloat64Sequence, + PyFloatSequence, + Vector1D, +) diff --git a/caikit/interfaces/common/data_model/vectors.py b/caikit/interfaces/common/data_model/vectors.py new file mode 100644 index 000000000..5189761d4 --- /dev/null +++ b/caikit/interfaces/common/data_model/vectors.py @@ -0,0 +1,198 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data structures for embedding vector representations +""" +# Standard +from dataclasses import dataclass, field +from typing import Any, Dict, List, Union +import json + +# Third Party +from google.protobuf import json_format +import numpy as np + +# First Party +from py_to_proto.dataclass_to_proto import Annotated, FieldNumber +import alog + +# Local +from caikit.core import DataObjectBase, dataobject +from caikit.core.data_model import PACKAGE_COMMON +from caikit.core.exceptions import error_handler + +log = alog.use_channel("DATAM") +error = error_handler.get(log) + + +@dataobject(PACKAGE_COMMON) +@dataclass +class PyFloatSequence(DataObjectBase): + values: Annotated[List[float], FieldNumber(1)] = field(default_factory=list) + + +@dataobject(PACKAGE_COMMON) +@dataclass +class NpFloat32Sequence(DataObjectBase): + values: Annotated[List[np.float32], FieldNumber(1)] + + @classmethod + def from_proto(cls, proto): + values = np.asarray(proto.values, dtype=np.float32) + return cls(values) + + +@dataobject(PACKAGE_COMMON) +@dataclass +class NpFloat64Sequence(DataObjectBase): + values: Annotated[List[np.float64], FieldNumber(1)] + + @classmethod + def from_proto(cls, proto): + values = np.asarray(proto.values, dtype=np.float64) + return cls(values) + + +@dataobject(PACKAGE_COMMON) +@dataclass +class Vector1D(DataObjectBase): + """Data representation for a 1 dimension vector of float-type data.""" + + data: Annotated[ + Union[ + PyFloatSequence, + NpFloat32Sequence, + NpFloat64Sequence, + ], + FieldNumber(1), + ] + + def __post_init__(self): + error.value_check( + "", + hasattr(self.data, "values"), + ValueError("Vector1D requires a float sequence data object with values."), + ) + + @classmethod + def from_vector(cls, vector): + dtype = getattr(vector, "dtype", False) + if dtype is None: + data = PyFloatSequence(vector) + elif dtype == np.float32: + data = NpFloat32Sequence(vector) + elif dtype == np.float64: + data = NpFloat64Sequence(vector) + else: + data = PyFloatSequence(vector) + return cls(data=data) + + @classmethod + def from_json(cls, json_str: Union[Dict[str, Any], str]) -> "Vector1D": + """JSON does not have different float types. 
Move data into the data_pyfloatsequence oneof field"""
+
+        json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str
+        data = json_obj.pop("data")
+        if data is not None:
+            json_obj["data_pyfloatsequence"] = data
+
+        json_str = json.dumps(json_obj)
+        try:
+            # Parse given JSON into google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType
+            parsed_proto = json_format.Parse(
+                json_str, cls.get_proto_class()(), ignore_unknown_fields=False
+            )
+
+            # Use from_proto to return the DataBase object from the parsed proto
+            return cls.from_proto(parsed_proto)
+
+        except json_format.ParseError as ex:
+            error("", ValueError(ex))
+
+    def to_dict(self) -> dict:
+        """to_dict is needed to make things serializable"""
+        values = self.data.values if self.data.values is not None else []
+        return {
+            "data": {
+                # coerce numpy.ndarray and numpy.float32 into JSON serializable list of floats
+                "values": values.tolist()
+                if isinstance(values, np.ndarray)
+                else values
+            }
+        }
+
+    @classmethod
+    def from_proto(cls, proto):
+        """Wrap the data in an appropriate float sequence, wrapped by this class"""
+        woo = proto.WhichOneof("data")
+        if woo is None:
+            return cls(PyFloatSequence())
+
+        woo_data = getattr(proto, woo)
+        if woo == "data_npfloat64sequence":
+            ret = cls(NpFloat64Sequence.from_proto(woo_data))
+        elif woo == "data_npfloat32sequence":
+            ret = cls(NpFloat32Sequence.from_proto(woo_data))
+        else:
+            ret = cls(PyFloatSequence.from_proto(woo_data))
+        return ret
+
+    def fill_proto(self, proto):
+        """Fill in the data in an appropriate data_"""
+        values = self.data.values
+        if values is not None and len(values) > 0:
+            sample = values[0]
+            error.type_check(
+                "", float, np.float32, np.float64, sample=sample
+            )
+            if isinstance(sample, np.float64):
+                proto.data_npfloat64sequence.values.extend(values)
+            elif isinstance(sample, np.float32):
+                proto.data_npfloat32sequence.values.extend(values)
+            else:
+                proto.data_pyfloatsequence.values.extend(values)
+
+        return proto
+
+
+@dataobject(PACKAGE_COMMON)
+class ListOfVector1D(DataObjectBase):
+    """Data representation for a list of 1D vectors, e.g. an embedding matrix"""
+
+    vectors: Annotated[List[Vector1D], FieldNumber(1)]
+
+    def __post_init__(self):
+        error.type_check("", list, vectors=self.vectors)
+        error.type_check_all("", Vector1D, vectors=self.vectors)
+
+    @classmethod
+    def from_json(cls, json_str: Union[Dict[str, Any], str]) -> "ListOfVector1D":
+        """Parse from JSON, moving each vector's data into the data_pyfloatsequence
+        oneof field"""
+
+        json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str
+        for v in json_obj["vectors"]:
+            data = v.pop("data")
+            if data is not None:
+                v["data_pyfloatsequence"] = data
+        json_str = json.dumps(json_obj)
+        try:
+            # Parse given JSON into google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType
+            parsed_proto = json_format.Parse(
+                json_str, cls.get_proto_class()(), ignore_unknown_fields=False
+            )
+
+            # Use from_proto to return the DataBase object from the parsed proto
+            return cls.from_proto(parsed_proto)
+
+        except json_format.ParseError as ex:
+            error("", ValueError(ex))
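As a hedged illustration (not part of this patch) of the dtype routing above: `from_vector` picks the oneof branch from the input's dtype, and `to_dict`/`from_json` move between the flat `{"data": ...}` JSON form and the typed proto fields. The values are arbitrary.

```python
# Hedged sketch of Vector1D dtype routing and JSON round-tripping.
import numpy as np

from caikit.interfaces.common.data_model import Vector1D

# Plain Python floats land in the data_pyfloatsequence branch...
v_py = Vector1D.from_vector([1.0, 2.0, 3.0])

# ...while numpy arrays are routed by dtype into the numpy-backed sequences
v_np = Vector1D.from_vector(np.asarray([0.5, 1.5], dtype=np.float32))

# to_dict() coerces the ndarray into a JSON-serializable list of floats
print(v_py.to_dict())  # {'data': {'values': [1.0, 2.0, 3.0]}}
print(v_np.to_dict())  # {'data': {'values': [0.5, 1.5]}}

# from_json() accepts the flat {"data": ...} form and re-nests it into the
# data_pyfloatsequence oneof field before parsing the proto
roundtrip = Vector1D.from_json(v_py.to_dict())
print(list(roundtrip.data.values))  # [1.0, 2.0, 3.0]
```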
diff --git a/caikit/interfaces/nlp/data_model/__init__.py b/caikit/interfaces/nlp/data_model/__init__.py
index ff1e6d715..909059549 100644
--- a/caikit/interfaces/nlp/data_model/__init__.py
+++ b/caikit/interfaces/nlp/data_model/__init__.py
@@ -13,7 +13,15 @@
 # limitations under the License.

 # Local
-from . import classification, package, text, text_generation
+from . import (
+    classification,
+    embedding_vectors,
+    package,
+    reranker,
+    sentence_similarity,
+    text,
+    text_generation,
+)
 from .classification import (
     ClassificationResult,
     ClassificationResults,
@@ -24,7 +32,14 @@
     TokenClassificationResults,
     TokenClassificationStreamResult,
 )
+from .embedding_vectors import EmbeddingResult, EmbeddingResults
 from .package import NLP_PACKAGE
+from .reranker import RerankResult, RerankResults, RerankScore, RerankScores
+from .sentence_similarity import (
+    SentenceSimilarityResult,
+    SentenceSimilarityResults,
+    SentenceSimilarityScores,
+)
 from .text import Token, TokenizationResults, TokenizationStreamResult
 from .text_generation import (
     FinishReason,
diff --git a/caikit/interfaces/nlp/data_model/classification.py b/caikit/interfaces/nlp/data_model/classification.py
index 679d17a10..b1e49d44f 100644
--- a/caikit/interfaces/nlp/data_model/classification.py
+++ b/caikit/interfaces/nlp/data_model/classification.py
@@ -14,6 +14,7 @@
 """Data structures for classification representations"""
 # Standard
+from enum import Enum
 from typing import List, Optional

 # Third Party
@@ -31,6 +32,21 @@
 log = alog.use_channel("DATAM")

+@dataobject(package=NLP_PACKAGE)
+class InputWarningReason(Enum):
+    UNSUITABLE_INPUT = 0
+
+
+@dataobject(package=NLP_PACKAGE)
+class InputWarning(DataObjectBase):
+    """Input warning data object, which carries a reason and message for warnings
+    issued to a user when the input causes errors (such as failed text generation)
+    """
+
+    id: Annotated[InputWarningReason, FieldNumber(1)]  # ID of the input warning reason
+    message: Annotated[str, FieldNumber(2)]  # Error message detailing the warning
+
+
 @dataobject(package=NLP_PACKAGE)
 class ClassificationTrainRecord(DataObjectBase):
     """A classification training record consisting of a single train instance."""
@@ -100,6 +116,7 @@ class TokenClassificationStreamResult(TokenClassificationResults):
     processed_index: Annotated[
         int, FieldNumber(2)
     ]  # Result index up to which text is processed
+    start_index: Annotated[int, FieldNumber(3)]  # Result start index for processed text

@@ -128,6 +145,9 @@ class TextGenTokenClassificationResults(DataObjectBase):
         Optional[np.uint64], FieldNumber(5)
     ]  # The random seed used for text generation
     input_token_count: Annotated[Optional[int], FieldNumber(6)]
+    warnings: Annotated[
+        Optional[List[InputWarning]], FieldNumber(9)
+    ]  # Warnings to the user in the event of input errors

@@ -140,3 +160,4 @@ class ClassifiedGeneratedTextStreamResult(ClassifiedGeneratedTextResult):
     processed_index: Annotated[
         Optional[int], FieldNumber(7)
     ]  # Result index up to which text is processed
+    start_index: Annotated[int, FieldNumber(8)]  # Result start index for processed text
diff --git a/caikit/interfaces/nlp/data_model/embedding_vectors.py b/caikit/interfaces/nlp/data_model/embedding_vectors.py
new file mode 100644
index 000000000..30163f142
--- /dev/null
+++ b/caikit/interfaces/nlp/data_model/embedding_vectors.py
@@ -0,0 +1,47 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data structures for embedding vector representations
+"""
+# Standard
+from dataclasses import dataclass
+
+# First Party
+from py_to_proto.dataclass_to_proto import Annotated, FieldNumber
+import alog
+
+# Local
+from ...common.data_model import ListOfVector1D, ProducerId, Vector1D
+from caikit.core import DataObjectBase, dataobject
+from caikit.core.exceptions import error_handler
+
+log = alog.use_channel("DATAM")
+error = error_handler.get(log)
+
+
+@dataobject(package="caikit_data_model.caikit_nlp")
+@dataclass
+class EmbeddingResult(DataObjectBase):
+    """Result from text embedding task"""
+
+    result: Annotated[Vector1D, FieldNumber(1)]
+    producer_id: Annotated[ProducerId, FieldNumber(2)]
+
+
+@dataobject(package="caikit_data_model.caikit_nlp")
+@dataclass
+class EmbeddingResults(DataObjectBase):
+    """Results from text embeddings task"""
+
+    results: Annotated[ListOfVector1D, FieldNumber(1)]
+    producer_id: Annotated[ProducerId, FieldNumber(2)]
diff --git a/caikit/interfaces/nlp/data_model/reranker.py b/caikit/interfaces/nlp/data_model/reranker.py
new file mode 100644
index 000000000..be700cafe
--- /dev/null
+++ b/caikit/interfaces/nlp/data_model/reranker.py
@@ -0,0 +1,66 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from typing import List, Optional
+
+# First Party
+from py_to_proto.dataclass_to_proto import Annotated, FieldNumber
+
+# Local
+from ...common.data_model import ProducerId
+from caikit.core import DataObjectBase, dataobject
+from caikit.core.data_model.json_dict import JsonDict
+
+
+@dataobject(package="caikit_data_model.caikit_nlp")
+class RerankScore(DataObjectBase):
+    """The score for one document (one query)"""
+
+    document: Annotated[Optional[JsonDict], FieldNumber(1)]
+    index: Annotated[int, FieldNumber(2)]
+    score: Annotated[float, FieldNumber(3)]
+    text: Annotated[Optional[str], FieldNumber(4)]
+
+
+@dataobject(package="caikit_data_model.caikit_nlp")
+class RerankScores(DataObjectBase):
+    """Scores for a query in a rerank task.
+    This is a list of n RerankScore entries, where n is based on the top_n documents, and
+    each score indicates the relevance of that document for this query. Results are ordered
+    most-relevant first.
+    """
+
+    query: Annotated[Optional[str], FieldNumber(1)]
+    scores: Annotated[List[RerankScore], FieldNumber(2)]
+
+
+@dataobject(package="caikit_data_model.caikit_nlp")
+class RerankResult(DataObjectBase):
+    """Result for one query in a rerank task.
+    This wraps the RerankScores for that query: n RerankScore entries, where n is based on
+    the top_n documents, ordered most-relevant first.
+    """
+
+    result: Annotated[RerankScores, FieldNumber(1)]
+    producer_id: Annotated[ProducerId, FieldNumber(2)]
+
+
+@dataobject(package="caikit_data_model.caikit_nlp")
+class RerankResults(DataObjectBase):
+    """Results list for rerank tasks (supporting multiple queries).
+    For multiple queries, each query has a RerankScores entry (ranking the documents for
+    that query).
+    """
+
+    results: Annotated[List[RerankScores], FieldNumber(1)]
+    producer_id: Annotated[ProducerId, FieldNumber(2)]
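A hedged sketch (not part of this patch) of assembling the rerank result types above; the query, documents, and scores are invented, and this assumes caikit dataobjects accept keyword construction as usual.

```python
# Hedged sketch: building a RerankResult from per-document scores.
from caikit.interfaces.common.data_model import ProducerId
from caikit.interfaces.nlp.data_model import RerankResult, RerankScore, RerankScores

scores = RerankScores(
    query="what is caikit?",
    scores=[
        # index is the document's position in the original input list
        RerankScore(document={"text": "Caikit is an AI toolkit"}, index=2, score=0.97,
                    text="Caikit is an AI toolkit"),
        RerankScore(document={"text": "Unrelated document"}, index=0, score=0.12,
                    text="Unrelated document"),
    ],
)
result = RerankResult(
    result=scores,
    producer_id=ProducerId("demo-reranker", "0.0.1"),  # hypothetical producer
)
```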
diff --git a/caikit/interfaces/nlp/data_model/sentence_similarity.py b/caikit/interfaces/nlp/data_model/sentence_similarity.py
new file mode 100644
index 000000000..5c908a1dc
--- /dev/null
+++ b/caikit/interfaces/nlp/data_model/sentence_similarity.py
@@ -0,0 +1,52 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data structures for sentence similarity representations
+"""
+# Standard
+from typing import List
+
+# First Party
+from py_to_proto.dataclass_to_proto import Annotated, FieldNumber
+import alog
+
+# Local
+from ...common.data_model import ProducerId
+from caikit.core import DataObjectBase, dataobject
+from caikit.core.exceptions import error_handler
+
+log = alog.use_channel("DATAM")
+error = error_handler.get(log)
+
+
+@dataobject(package="caikit_data_model.caikit_nlp")
+class SentenceSimilarityScores(DataObjectBase):
+    """Scores for a sentence similarity task"""
+
+    scores: Annotated[List[float], FieldNumber(1)]
+
+
+@dataobject(package="caikit_data_model.caikit_nlp")
+class SentenceSimilarityResult(DataObjectBase):
+    """Result for sentence similarity task"""
+
+    result: Annotated[SentenceSimilarityScores, FieldNumber(1)]
+    producer_id: Annotated[ProducerId, FieldNumber(2)]
+
+
+@dataobject(package="caikit_data_model.caikit_nlp")
+class SentenceSimilarityResults(DataObjectBase):
+    """Results list for sentence similarity tasks"""
+
+    results: Annotated[List[SentenceSimilarityScores], FieldNumber(1)]
+    producer_id: Annotated[ProducerId, FieldNumber(2)]
diff --git a/caikit/interfaces/nlp/data_model/text.py b/caikit/interfaces/nlp/data_model/text.py
index 87dd5f15b..b80d4bedf 100644
--- a/caikit/interfaces/nlp/data_model/text.py
+++ b/caikit/interfaces/nlp/data_model/text.py
@@ -55,3 +55,4 @@ class TokenizationStreamResult(TokenizationResults):
     processed_index: Annotated[
         int, FieldNumber(2)
     ]  # Result index up to which text is processed
+    start_index: Annotated[int, FieldNumber(3)]  # Result start index for processed text
diff --git a/caikit/interfaces/nlp/tasks.py b/caikit/interfaces/nlp/tasks.py
index c8e16689b..73f5ee539 100644
--- a/caikit/interfaces/nlp/tasks.py
+++ b/caikit/interfaces/nlp/tasks.py
@@ -16,10 +16,12 @@
 """
 # Standard
-from typing import Iterable
+from typing import Iterable, List

 # Local
 from ...core import TaskBase, task
+from ...core.data_model.json_dict import JsonDict
+from .data_model import SentenceSimilarityResult, SentenceSimilarityResults
 from .data_model.classification import (
     ClassificationResults,
     ClassifiedGeneratedTextResult,
@@ -27,6 +29,8 @@
     TokenClassificationResults,
     TokenClassificationStreamResult,
 )
+from .data_model.embedding_vectors import EmbeddingResult, EmbeddingResults
+from .data_model.reranker import RerankResult, RerankResults
 from .data_model.text import TokenizationResults, TokenizationStreamResult
 from .data_model.text_generation import GeneratedTextResult, GeneratedTextStreamResult
@@ -82,3 +86,85 @@ class ClassificationWithTextGenerationTask(TaskBase):
     input prompting text, generating additional text from that prompt and classifying the
     generated text based on detectors.
     """
+
+
+@task(
+    required_parameters={"text": str},
+    output_type=EmbeddingResult,
+)
+class EmbeddingTask(TaskBase):
+    """Return a text embedding for the input text string"""
+
+
+@task(
+    required_parameters={"texts": List[str]},
+    output_type=EmbeddingResults,
+)
+class EmbeddingTasks(TaskBase):
+    """Return a text embedding for each text string in the input list"""
+
+
+@task(
+    required_parameters={
+        "documents": List[JsonDict],
+        "query": str,
+    },
+    output_type=RerankResult,
+)
+class RerankTask(TaskBase):
+    """Returns an ordered list ranking the most relevant documents for the query
+
+    Required parameters:
+        query: The search query
+        documents: JSON documents containing "text" or alternative "_text" to search
+    Returns:
+        The top_n documents in order of relevance (most relevant first).
+        For each, a score and document index (position in input) are returned.
+        The original document JSON is returned if requested via optional args.
+        The optional top_n parameter limits the number of results when used.
+    """
+
+
+@task(
+    required_parameters={
+        "documents": List[JsonDict],
+        "queries": List[str],
+    },
+    output_type=RerankResults,
+)
+class RerankTasks(TaskBase):
+    """Returns an ordered list for each query ranking the most relevant documents for the query
+
+    Required parameters:
+        queries: The search queries
+        documents: JSON documents containing "text" or alternative "_text" to search
+    Returns:
+        Results in order of the queries.
+        In each query result:
+            The query text is optionally included for visual convenience.
+            The top_n documents in order of relevance (most relevant first).
+            For each, a score and document index (position in input) are returned.
+            The original document JSON is returned if requested via optional args.
+            The optional top_n parameter limits the number of results when used.
+    """
+
+
+@task(
+    required_parameters={"source_sentence": str, "sentences": List[str]},
+    output_type=SentenceSimilarityResult,
+)
+class SentenceSimilarityTask(TaskBase):
+    """Compare the source_sentence to each of the sentences.
+    Result contains a list of scores in the order of the input sentences.
+    """
+
+
+@task(
+    required_parameters={"source_sentences": List[str], "sentences": List[str]},
+    output_type=SentenceSimilarityResults,
+)
+class SentenceSimilarityTasks(TaskBase):
+    """Compare each of the source_sentences to each of the sentences.
+    Returns a list of results in the order of the source_sentences.
+    Each result contains a list of scores in the order of the input sentences.
+    """
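For illustration, a hedged sketch (not part of this patch) of a module binding to the new EmbeddingTask: the `run` signature matches the task's `required_parameters` and `output_type` declared above. The module id, name, and "embedding" logic are placeholders.

```python
# Hedged sketch: a toy module implementing EmbeddingTask.
from caikit.core import ModuleBase, module
from caikit.interfaces.common.data_model import ProducerId, Vector1D
from caikit.interfaces.nlp.data_model import EmbeddingResult
from caikit.interfaces.nlp.tasks import EmbeddingTask


@module(
    id="11111111-1111-1111-1111-111111111111",  # placeholder module UUID
    name="HashEmbedding",
    version="0.0.1",
    task=EmbeddingTask,  # binds this module to the task declared above
)
class HashEmbedding(ModuleBase):
    """Toy embedder: maps text onto a fixed-size vector of character codes."""

    def run(self, text: str) -> EmbeddingResult:
        # Not a real embedding; just enough to satisfy the task signature
        values = [float(ord(c) % 13) for c in text[:8]] or [0.0]
        return EmbeddingResult(
            result=Vector1D.from_vector(values),
            producer_id=ProducerId("hash-embedding", "0.0.1"),
        )
```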
diff --git a/caikit/interfaces/runtime/data_model/__init__.py b/caikit/interfaces/runtime/data_model/__init__.py
index db2bf67da..88dd153c9 100644
--- a/caikit/interfaces/runtime/data_model/__init__.py
+++ b/caikit/interfaces/runtime/data_model/__init__.py
@@ -14,6 +14,13 @@

 # Local
 from . import training_management
+from .info import (
+    ModelInfo,
+    ModelInfoRequest,
+    ModelInfoResponse,
+    RuntimeInfoRequest,
+    RuntimeInfoResponse,
+)
 from .training_management import (
     ModelPointer,
     TrainingInfoRequest,
diff --git a/caikit/interfaces/runtime/data_model/info.py b/caikit/interfaces/runtime/data_model/info.py
new file mode 100644
index 000000000..0e42b8f5a
--- /dev/null
+++ b/caikit/interfaces/runtime/data_model/info.py
@@ -0,0 +1,70 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file contains interfaces to handle information requests
+"""
+
+# Standard
+from typing import Dict, List, Optional
+
+# First Party
+from py_to_proto.dataclass_to_proto import Annotated, FieldNumber
+import alog
+
+# Local
+from caikit.core.data_model import PACKAGE_COMMON, DataObjectBase, dataobject
+
+log = alog.use_channel("RUNTIMEOPS")
+
+RUNTIME_PACKAGE = f"{PACKAGE_COMMON}.runtime"
+
+
+@dataobject(RUNTIME_PACKAGE)
+class RuntimeInfoRequest(DataObjectBase):
+    """Empty request for runtime server information"""
+
+
+@dataobject(RUNTIME_PACKAGE)
+class RuntimeInfoResponse(DataObjectBase):
+    runtime_version: Annotated[Optional[str], FieldNumber(1)]
+    python_packages: Annotated[Dict[str, str], FieldNumber(2)]
+
+
+@dataobject(RUNTIME_PACKAGE)
+class ModelInfoRequest(DataObjectBase):
+    """Request for model information, optionally limited to a set of model IDs"""
+
+    model_ids: Annotated[Optional[List[str]], FieldNumber(1)]
+
+
+@dataobject(RUNTIME_PACKAGE)
+class ModelInfo(DataObjectBase):
+    """Information regarding a specific Model instance"""
+
+    # Model information
+    model_path: Annotated[str, FieldNumber(1)]
+    name: Annotated[str, FieldNumber(2)]
+    size: Annotated[int, FieldNumber(3)]
+    metadata: Annotated[Dict[str, str], FieldNumber(4)]
+
+    # Module Information
+    module_id: Annotated[str, FieldNumber(5)]
+    module_metadata: Annotated[Dict[str, str], FieldNumber(6)]
+
+
+@dataobject(RUNTIME_PACKAGE)
+class ModelInfoResponse(DataObjectBase):
+    """Model Info response contains a list of ModelInfos"""
+
+    models: Annotated[List[ModelInfo], FieldNumber(1)]
diff --git a/caikit/interfaces/ts/__init__.py b/caikit/interfaces/ts/__init__.py
new file mode 100644
index 000000000..4d71e7d19
--- /dev/null
+++ b/caikit/interfaces/ts/__init__.py
@@ -0,0 +1,16 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .
import data_model, tasks diff --git a/caikit/interfaces/ts/data_model/__init__.py b/caikit/interfaces/ts/data_model/__init__.py new file mode 100644 index 000000000..09dcda5b2 --- /dev/null +++ b/caikit/interfaces/ts/data_model/__init__.py @@ -0,0 +1,36 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Data model definitions for structures in the time series domain +""" + +# ordering is important here to permit protobuf loading and dynamic +# `caikit.core` setup +# pylint: disable=wrong-import-order,wrong-import-position + +# Local +# Import the protobufs +from .package import TS_PACKAGE +from .time_types import ( + PeriodicTimeSequence, + PointTimeSequence, + Seconds, + TimeDuration, + TimePoint, + ValueSequence, +) + +from ._single_timeseries import SingleTimeSeries # isort:skip +from .timeseries import TimeSeries # isort:skip +from .timeseries_evaluation import Id, EvaluationRecord, EvaluationResult # isort:skip diff --git a/caikit/interfaces/ts/data_model/_single_timeseries.py b/caikit/interfaces/ts/data_model/_single_timeseries.py new file mode 100644 index 000000000..72e7cc471 --- /dev/null +++ b/caikit/interfaces/ts/data_model/_single_timeseries.py @@ -0,0 +1,410 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +The core data model object for a TimeSeries +""" +# Standard +from datetime import timedelta +from typing import Iterable, List, Optional, Tuple, Union +import json + +# Third Party +import dateutil.parser +import numpy as np +import pandas as pd + +# First Party +from py_to_proto.dataclass_to_proto import ( # Annotated imported from here for compatibility + Annotated, + FieldNumber, + OneofField, +) +import alog + +# Local +from ....core import DataObjectBase +from ....core.data_model import dataobject +from ....core.exceptions import error_handler +from ..data_model.backends.util import strip_periodic +from .backends.base import TimeSeriesBackendBase +from .backends.pandas_backends import PandasTimeSeriesBackend +from .package import TS_PACKAGE +from .time_types import PeriodicTimeSequence, PointTimeSequence, Seconds, ValueSequence +from .toolkit.optional_dependencies import HAVE_PYSPARK, pyspark +from .toolkit.sparkconf import sparkconf_local + +log = alog.use_channel("TSDM") +error = error_handler.get(log) + +## TimeSeries ################################################################## + + +@dataobject(package=TS_PACKAGE) +class SingleTimeSeries(DataObjectBase): + """The TimeSeries object is the central data container for the library. + At present it wraps either a pandas.DataFrame, or pyspark.sql.DataFrame to bind + into the caikit data model. + """ + + @dataobject(package=TS_PACKAGE) + class StringIDSequence(DataObjectBase): + """Nested value sequence of strings""" + + values: Annotated[List[str], FieldNumber(1)] + + @dataobject(package=TS_PACKAGE) + class IntIDSequence(DataObjectBase): + """Nested value sequence of ints""" + + values: Annotated[List[int], FieldNumber(1)] + + time_sequence: Union[ + Annotated[PeriodicTimeSequence, OneofField("time_period"), FieldNumber(10)], + Annotated[PointTimeSequence, OneofField("time_points"), FieldNumber(20)], + ] + values: Annotated[List[ValueSequence], FieldNumber(1)] + timestamp_label: Annotated[str, FieldNumber(2)] + value_labels: Annotated[List[str], FieldNumber(3)] + ids: Union[ + Annotated[IntIDSequence, OneofField("id_int"), FieldNumber(30)], + Annotated[StringIDSequence, OneofField("id_str"), FieldNumber(40)], + ] + + _DEFAULT_TS_COL = "timestamp" + + # TODO: We need to clean up the init semantics + def __init__(self, *args, **kwargs): + """Constructing a TimeSeries directly always delegates to the pandas + backend + """ + # this is called from MultiTimeSeries + if backend := kwargs.get("_backend", None): + self._backend = backend + elif "values" in kwargs: + self._ids = None + if "id_int" in kwargs: + self._which_oneof_ids = "id_int" + self._ids = kwargs["id_int"] + if "id_str" in kwargs: + self._which_oneof_ids = "id_str" + self._ids = kwargs["id_str"] + if "time_period" in kwargs: + self._which_oneof_time_sequence = "time_period" + self._time_sequence = kwargs["time_period"] + if "time_points" in kwargs: + self._which_oneof_time_sequence = "time_points" + self._time_sequence = kwargs["time_points"] + + for k, v in kwargs.items(): + setattr(self, k, v) + + else: + error.value_check( + "", + len(args) != 0, + "must have at least the data argument", + args, + ) + data_arg = args[0] + + if isinstance(data_arg, pd.DataFrame): + self._backend = PandasTimeSeriesBackend(*args, **kwargs) + elif HAVE_PYSPARK and isinstance(data_arg, pyspark.sql.DataFrame): + # Local + # pylint: disable=import-outside-toplevel + from .backends._spark_backends import SparkTimeSeriesBackend + + self._backend = SparkTimeSeriesBackend(*args, **kwargs) + 
else: + raise NotImplementedError("not implemented yet") + + def _get_pd_df(self) -> Tuple[pd.DataFrame, str, Iterable[str]]: + """Convert the data to a pandas DataFrame, efficiently if possible""" + + # If there is a backend that knows how to do the conversion, use that + backend = getattr(self, "_backend", None) + if backend is not None and isinstance(backend, TimeSeriesBackendBase): + log.debug("Using backend pandas conversion") + return backend.as_pandas() + + # If not, convert the slow way from the proto representation + df_kwargs = {} + + # Since all fields are optional, we need to ensure that the + # time_sequence oneof has been set and that there are values + error.value_check( + "", + self.time_sequence is not None, + "Cannot create pandas data frame without a time sequence", + ) + error.value_check( + "", + self.values is not None, + "Cannot create pandas data frame without values", + ) + + # Determine the number of rows we'll expect + col_lens = {len(col.sequence.values) for col in self.values} + error.value_check( + "", + len(col_lens) == 1, + "Not all columns have matching lengths", + ) + num_rows = list(col_lens)[0] + log.debug("Num rows: %d", num_rows) + + # If the time index is stored periodically, this can be represented as a + # periodic index in pandas iff the start time and period are grounded in + # real datetime space. If they are purely numerical, they can be + # converted to a set of point values. The only invalid combination is a + # numeric start time and a timedelta duration. + # + # (datetime, numeric) -> period w/ numeric seconds + # (datetime, str) -> period w/ string freq + # (datetime, timedelta) -> period w/ timedelta freq + # (numeric, numeric) -> point sequence + # (numeric, [str, timedelta]) -> INVALID + if self.time_period is not None: + start_time = self.time_period.start_time + period_length = self.time_period.period_length + error.value_check( + "", + start_time.time is not None, + "start_time must be set in time_period", + ) + error.value_check( + "", + period_length.time is not None, + "period_length must be set in time_period", + ) + + numeric_start_time = start_time.ts_epoch is None + numeric_period = period_length.dt_str is None and ( + period_length.dt_int is not None or period_length.dt_float is not None + ) + error.value_check( + "", + not (numeric_start_time and not numeric_period), + "Time period cannot have a numeric start_time with a timedelta period_length", + ) + + if numeric_start_time: + df_kwargs["index"] = pd.RangeIndex( + start=start_time.time, + stop=period_length.time * num_rows, + step=period_length.time, + ) + elif numeric_period: + df_kwargs["index"] = pd.period_range( + start_time.ts_epoch.as_datetime(), + freq=timedelta(seconds=period_length.time), + periods=num_rows, + ) + else: + df_kwargs["index"] = pd.period_range( + start_time.ts_epoch.as_datetime(), + freq=period_length.dt_str, + periods=num_rows, + ) + + # Otherwise, interpret the sequence of time values directly + else: + time_points = self.time_points.points + error.value_check( + "", + time_points is not None and len(time_points) == num_rows, + "Number of time points {} doesn't match number of rows {}", + -1 if time_points is None else len(time_points), + num_rows, + ) + if time_points: + # Convert to a sequence of contiguous points + time_point_values = [tp.time for tp in time_points] + time_point_type = type(time_point_values[0]) + error.type_check_all( + "", + time_point_type, + time_point_values=time_point_values, + ) + + # If the type needs conversion to 
datetimes, do so + if time_point_type == Seconds: + time_point_values = [val.as_datetime() for val in time_point_values] + + df_kwargs["index"] = time_point_values + + # Make the columns dict + value_labels = self.value_labels or range(len(self.values)) + error.value_check( + "", + len(value_labels) == len(self.values), + "Wrong number of value labels {} for {} value columns", + len(value_labels), + len(self.values), + ) + + def deserialize_values_if_necessary(sequence_values): + if isinstance(sequence_values, ValueSequence.TimePointSequence): + return [dateutil.parser.parse(v) for v in sequence_values.values] + if isinstance(sequence_values, ValueSequence.AnyValueSequence): + return [json.loads(v) for v in sequence_values.values] + if isinstance(sequence_values, ValueSequence.VectorValueSequence): + # this is required as the underlying type is just a repeated scalar field, we need + # it to be a list + return [list(v) for v in sequence_values.values] + return sequence_values.values + + df_kwargs["data"] = dict( + zip( + value_labels, + (deserialize_values_if_necessary(col.sequence) for col in self.values), + ) + ) + + result_df = pd.DataFrame(**df_kwargs) + if self.timestamp_label != "": + result_df.reset_index(inplace=True) + result_df = result_df.rename(columns={"index": self.timestamp_label}) + + # Not exposing the _single_timeseries and the dataframe will be cached elsewhere + # self._backend = PandasTimeSeriesBackend(result_df, timestamp_column=self.timestamp_label, + # value_columns=value_labels) + # Make the data frame + return result_df, self.timestamp_label, value_labels + + def __len__(self) -> int: + """Return the length of the single time series object. + + Returns: + int: Length + """ + if self.values: + return len(self.values[0].sequence.values) + return 0 + + def __eq__(self, other: "SingleTimeSeries") -> bool: + """Equivalence operator for SingleTimeSeries objects. + + Performs ordering of data based on timestamp_label prior to checking for equivalence. Relies + on underlying pandas equivalence testing function `pd.testing.assert_frame_equal`. + + Args: + other (SingleTimeSeries): SingleTimeSeries to test against. + + Returns: + bool: True if the SingleTimeSeries are equivalent. 
+ """ + + error.type_check("", SingleTimeSeries, other=other) + + if self.timestamp_label != other.timestamp_label: + return False + + sort_columns = [self.timestamp_label] if self.timestamp_label else [] + + try: + pd.testing.assert_frame_equal( + self.as_pandas().sort_values(by=sort_columns), + other.as_pandas().sort_values(by=sort_columns), + ) + except AssertionError: + return False + + return True + + ## Views ## + + def _as_pandas_ops(self, adf, include_timestamps: Union[None, bool] = False): + """operate on pandas-like object instead of strictly pandas""" + backend_df = adf + + # if we want to include timestamps, but it is not already in the dataframe, we need to add + if include_timestamps and self.timestamp_label is None: + dftouse = backend_df.copy(deep=False) # this does seem to be necessary + dftouse[self.__class__._DEFAULT_TS_COL] = ( + list(range(len(dftouse))) + if isinstance(dftouse, pyspark.pandas.DataFrame) + else np.arange(len(dftouse)) + ) + return dftouse + # if we do not want timestamps, but we already have them in the dataframe, we need to return + # a view without timestamps + if ( + include_timestamps is not None and not include_timestamps + ) and self.timestamp_label is not None: + return backend_df.loc[:, backend_df.columns != self.timestamp_label] + + return backend_df + + def as_pandas(self, include_timestamps: Optional[bool] = None) -> "pd.DataFrame": + """Get the view of this timeseries as a pandas DataFrame + + Args: + include_timestamps (bool, optional): Control the addition or removal of + timestamps. True will include timestamps, generating if needed, while False will + remove timestamps. Use None to returned what is available, leaving unchanged. + Defaults to None. + + Returns: + pd.DataFrame: The view of the data as a pandas DataFrame + """ + backend_df = self._get_pd_df()[0] + return self._as_pandas_ops( + adf=backend_df, include_timestamps=include_timestamps + ) + + def as_spark( + self, include_timestamps: Optional[bool] = None + ) -> "pyspark.sql.DataFrame": + """Get the view of this timeseries as a spark DataFrame + + Args: + include_timestamps (bool, optional): Control the addition or removal of + timestamps. True will include timestamps, generating if needed, while False will + remove timestamps. Use None to returned what is available, leaving unchanged. + Defaults to None. + + Returns: + pyspark.sql.DataFrame: The view of the data as a spark DataFrame + """ + if not HAVE_PYSPARK: + raise NotImplementedError( + "You must have pyspark installed for this to work!" 
+ ) + + # Third Party + # pylint: disable=import-outside-toplevel + from pyspark.pandas import DataFrame as psdataframe + from pyspark.sql import SparkSession + + # Local + # pylint: disable=import-outside-toplevel + from .backends._spark_backends import SparkTimeSeriesBackend + + # If there is a backend that knows how to do the conversion, use that + backend = getattr(self, "_backend", None) + if backend is not None and isinstance(backend, SparkTimeSeriesBackend): + backend_df = backend._pyspark_df + pandas_like: psdataframe = backend_df.pandas_api() + timeseries_magic = self._as_pandas_ops( + pandas_like, include_timestamps=include_timestamps + ) + return timeseries_magic.to_spark() + + spark = SparkSession.builder.config(conf=sparkconf_local()).getOrCreate() + return spark.createDataFrame( + strip_periodic( + input_df=self.as_pandas(include_timestamps=include_timestamps) + ) + ) diff --git a/caikit/interfaces/ts/data_model/backends/__init__.py b/caikit/interfaces/ts/data_model/backends/__init__.py new file mode 100644 index 000000000..2068258bf --- /dev/null +++ b/caikit/interfaces/ts/data_model/backends/__init__.py @@ -0,0 +1,13 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/caikit/interfaces/ts/data_model/backends/_spark_backends.py b/caikit/interfaces/ts/data_model/backends/_spark_backends.py new file mode 100644 index 000000000..d896a2e5c --- /dev/null +++ b/caikit/interfaces/ts/data_model/backends/_spark_backends.py @@ -0,0 +1,233 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Core data model backends backed by pyspark.sql.DataFrame. + +This module is not intended for direct importing. It's used +by the caikit ts datamodel. Directly importing this module +will force a hard spark dependency which we do not want +to do. 
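Spark's schema inference does not understand pandas `Period` objects, which is why the conversion above routes through `strip_periodic` (defined later in this patch) before `createDataFrame`. A minimal pandas-only illustration of that transformation, with illustrative column names:

```python
import pandas as pd

# A frame whose timestamp column holds pandas Periods
df = pd.DataFrame({"value": [1.0, 2.0, 3.0]})
df["ts"] = pd.period_range("2023-01-01", periods=3, freq="D")
print(df["ts"].dtype)  # period[D] -- Spark cannot infer a type for this

# The per-cell conversion strip_periodic applies: Period -> Timestamp
df["ts"] = [p.to_timestamp() if hasattr(p, "to_timestamp") else p for p in df["ts"]]
print(df["ts"].dtype)  # datetime64[ns] -- safe to hand to spark.createDataFrame
```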
+""" + +# Standard +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Iterable, Optional, Tuple, Type, Union + +# Third Party +import pandas as pd + +# this import is ok because this module is NOT proactively imported +import pyspark + +# First Party +import alog + +# Local +from .....core.data_model import ProducerId +from .....core.exceptions import error_handler +from .._single_timeseries import SingleTimeSeries +from .base import MultiTimeSeriesBackendBase, TimeSeriesBackendBase +from .pandas_backends import PandasMultiTimeSeriesBackend, PandasTimeSeriesBackend +from .spark_util import mock_pd_groupby + +if TYPE_CHECKING: + # Local + from ..timeseries import TimeSeries + +log = alog.use_channel("SPBCK") +error = error_handler.get(log) + + +@contextmanager +def ensure_spark_cached(dataframe: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame: + """Will ensure that a given dataframe is cached. + If dataframe is already cached it does nothing. If it's not + cached, it will cache it and then uncache the object when + the ensure_spark_cached object container goes out of scope. Users + must utilize the with pattern of access. + + Example: + ```python + with ensure_spark_cached(df) as _: + # do dataframey sorts of things on df + # it's guarenteed to be cached + # inside this block + # that's it, you're done. + # df remains cached if it already was + # or it's no longer cached if it wasn't + # before entering the with block above. + ``` + """ + do_cache = hasattr(dataframe, "cache") and not dataframe.is_cached + if do_cache: + dataframe.cache() + yield dataframe + if do_cache: + dataframe.unpersist() + + +class SparkMultiTimeSeriesBackend(MultiTimeSeriesBackendBase): + def __init__( + self, + data_frame: pyspark.sql.DataFrame, + key_column: Union[Iterable[str], str], + timestamp_column: str = None, + value_columns: Optional[Iterable[str]] = None, + ids: Optional[Union[Iterable[int], Iterable[str]]] = None, + producer_id: Optional[Union[Tuple[str, str], ProducerId]] = None, + ): + error.type_check("", pyspark.sql.DataFrame, data_frame=data_frame) + + # for param validation + pd_mts = PandasMultiTimeSeriesBackend( + data_frame=pd.DataFrame(columns=data_frame.columns), + key_column=key_column, + timestamp_column=timestamp_column, + value_columns=value_columns, + ids=ids, + producer_id=producer_id, + ) + + self._pyspark_df: pyspark.sql.DataFrame = data_frame + # for tapping into pandas api call when needed + self._pyspark_pandas_df = self._pyspark_df.pandas_api() + self._key_column = key_column + self._timestamp_column = timestamp_column + # pylint: disable=duplicate-code + self._value_columns = value_columns or [ + col + for col in data_frame.columns + if col != timestamp_column and col not in key_column + ] + self._ids = [] if ids is None else ids + self._producer_id = ( + producer_id + if isinstance(producer_id, ProducerId) + else (ProducerId(*producer_id) if producer_id is not None else None) + ) + self._key_columns = pd_mts._key_columns + + def get_attribute(self, data_model_class: Type["TimeSeries"], name: str) -> Any: + if name == "timeseries": + result = [] + + if len(self._key_columns) == 0: + with ensure_spark_cached(self._pyspark_df) as _: + backend = SparkTimeSeriesBackend( + data_frame=self._pyspark_df, + timestamp_column=self._timestamp_column, + value_columns=self._value_columns, + ) + result.append(SingleTimeSeries(_backend=backend)) + else: + with ensure_spark_cached(self._pyspark_df) as _: + for ids, spark_df in mock_pd_groupby( + self._pyspark_df, 
by=self._key_columns + ): + k = ids + if isinstance(k, (str, int)): + k = [k] + backend = SparkTimeSeriesBackend( + data_frame=spark_df, + timestamp_column=self._timestamp_column, + value_columns=self._value_columns, + ids=k, + ) + result.append(SingleTimeSeries(_backend=backend)) + return result + + if name == "id_labels": + return self._key_columns + + # If requesting producer_id or ids, just return the stored value + if name == "producer_id": + return self._producer_id + + raise ValueError(f"Provided an attribute name that does not exist - {name}") + + def as_pandas(self) -> Tuple[pd.DataFrame, Iterable[str], str, Iterable[str]]: + return ( + self._pyspark_df.toPandas(), + self._key_column, + self._timestamp_column, + self._value_columns, + ) + + +class SparkTimeSeriesBackend(TimeSeriesBackendBase): + """The SparkTimeSeries is responsible for managing the standard + in-memory representation of a TimeSeries using a spark backend compute engine. + """ + + def __init__( + self, + data_frame: pyspark.sql.DataFrame, + timestamp_column: Optional[str] = None, + value_columns: Optional[Iterable[str]] = None, + ids: Optional[Iterable[int]] = None, + ): + """At init time, hold onto the data frame as well as the arguments that + tell where the time and values live + + Args: + data_frame: pyspark.sql.DataFrame + The raw data frame + timestamp_column: Optional[str] + The name of the column holding the timestamps. If set to None, timestamps will be + assigned based on the rows index (default is None) + value_columns: Optional[Iterable[str]] + A sequence of names of columns to hold as values + ids: Optional[iterable[int]] + A sequence of numeric IDs associated with this TimeSeries + """ + + # Validators special to this class + error.type_check("", pyspark.sql.DataFrame, data_frame=data_frame) + + self._pyspark_df: pyspark.sql.DataFrame = data_frame + + # for tapping into pandas api call when needed + self._pyspark_pandas_df = self._pyspark_df.pandas_api() + + # this will give us basic parameter validation + self._pdbackend_helper = PandasTimeSeriesBackend( + data_frame=pd.DataFrame(columns=data_frame.columns), + value_columns=value_columns, + timestamp_column=str(timestamp_column) + if timestamp_column is not None + else timestamp_column, + ids=ids, + ) + + def get_attribute( + self, data_model_class: Type["SingleTimeSeries"], name: str + ) -> Any: + """When fetching a data attribute from the timeseries, this aliases to + the appropriate set of backend wrappers for the various fields. + """ + + with ensure_spark_cached(self._pyspark_df) as _: + return self._pdbackend_helper.get_attribute( + data_model_class=data_model_class, + name=name, + external_df=self._pyspark_pandas_df, + ) + + def as_pandas(self) -> Tuple[pd.DataFrame, str, Iterable[str]]: + return ( + self._pyspark_df.toPandas(), + self._pdbackend_helper._timestamp_column, + self._pdbackend_helper._value_columns, + ) diff --git a/caikit/interfaces/ts/data_model/backends/base.py b/caikit/interfaces/ts/data_model/backends/base.py new file mode 100644 index 000000000..c7297c9ce --- /dev/null +++ b/caikit/interfaces/ts/data_model/backends/base.py @@ -0,0 +1,103 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Base classes to share between data model backends +""" + +# Standard +from typing import Any, Iterable, Tuple, Type +import abc + +# Third Party +import pandas as pd + +# First Party +import alog + +# Local +from caikit.core.data_model.data_backends import DataModelBackendBase +from caikit.core.exceptions import error_handler + +log = alog.use_channel("DMBCK") +error = error_handler.get(log) + + +class UncachedBackendMixin(DataModelBackendBase): + """Intermediate base class that disables attribute caching""" + + def cache_attribute(self, *_, **__) -> bool: + """Never cache attributes""" + return False + + +class StrictFieldBackendMixin(DataModelBackendBase): + """Intermediate base class that raises attribute errors for unknown fields""" + + def get_attribute(self, data_model_class: Type, name: str) -> Any: + """Base implementation that raises an AttributeError on bad attr names. + It should be called after object-specific logic. + """ + if name not in data_model_class.fields: + error( + "", + AttributeError( + f"No such attribute [{name}] on [{data_model_class.__name__}]" + ), + ) + + +class TimeSeriesBackendBase(UncachedBackendMixin, StrictFieldBackendMixin): + """Abstract base class for all backends of the central TimeSeries data model + type + """ + + @abc.abstractmethod + def as_pandas(self) -> Tuple[pd.DataFrame, str, Iterable[str]]: + """All backends must implement the ability to coerce their underlying + data into a pandas DataFrame and provide the pointers to the timeseries + source and value source(s) + + Returns: + df: pd.DataFrame + The data frame itself + timestamp_source: str + The column name (or None) indicating where the + timestamp sequence can be found + value_source: Iterable[str] + The names of the columns holding value sequences + """ + + +class MultiTimeSeriesBackendBase(UncachedBackendMixin, StrictFieldBackendMixin): + """Abstract base class for all backends of the central MultiTimeSeries data model + type + """ + + @abc.abstractmethod + def as_pandas(self) -> Tuple[pd.DataFrame, Iterable[str], str, Iterable[str]]: + """All backends must implement the ability to coerce their underlying + data into a pandas DataFrame and provide the pointers to the timeseries + source and value source(s) + + Returns: + df: pd.DataFrame + The data frame itself + key_source: Iterable[str] + the names of the columns holding key values + timestamp_source: str + The column name (or None) indicating where the + timestamp sequence can be found + value_source: Iterable[str] + The names of the columns holding value sequences + """ diff --git a/caikit/interfaces/ts/data_model/backends/dfcache.py b/caikit/interfaces/ts/data_model/backends/dfcache.py new file mode 100644 index 000000000..64e2076c2 --- /dev/null +++ b/caikit/interfaces/ts/data_model/backends/dfcache.py @@ -0,0 +1,48 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
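To make the contract above concrete, here is a minimal, hypothetical backend that serves a time series from an in-memory dict of columns. It is a sketch only: it answers just the two label fields explicitly and leans on `StrictFieldBackendMixin` for unknown names.

```python
from typing import Any, Iterable, Tuple, Type

import pandas as pd


class DictTimeSeriesBackend(TimeSeriesBackendBase):
    """Hypothetical backend over a plain {column: values} dict."""

    def __init__(self, columns: dict, timestamp_column: str):
        self._df = pd.DataFrame(columns)
        self._timestamp_column = timestamp_column
        self._value_columns = [c for c in self._df.columns if c != timestamp_column]

    def get_attribute(self, data_model_class: Type, name: str) -> Any:
        if name == "timestamp_label":
            return self._timestamp_column
        if name == "value_labels":
            return self._value_columns
        # Unknown names fall through to StrictFieldBackendMixin's AttributeError
        return super().get_attribute(data_model_class, name)

    def as_pandas(self) -> Tuple[pd.DataFrame, str, Iterable[str]]:
        return self._df, self._timestamp_column, self._value_columns
```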
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities related to managing spark DataFrame caching"""
+
+# Standard
+from contextlib import contextmanager
+
+# Third Party
+from pyspark.sql import DataFrame
+
+
+@contextmanager
+def ensure_spark_cached(dataframe: DataFrame) -> DataFrame:
+    """Ensures that a given dataframe is cached.
+    If the dataframe is already cached, this does nothing. If it's not
+    cached, it will cache it and then uncache it when
+    the ensure_spark_cached context manager goes out of scope. Users
+    must use the with pattern of access.
+
+    Example:
+    ```python
+    with ensure_spark_cached(df) as _:
+        # do dataframey sorts of things on df
+        # it's guaranteed to be cached
+        # inside this block
+        # that's it, you're done.
+        # df remains cached if it already was
+        # or it's no longer cached if it wasn't
+        # before entering the with block above.
+    ```
+    """
+    do_cache = hasattr(dataframe, "cache") and not dataframe.is_cached
+    if do_cache:
+        dataframe.cache()
+    yield dataframe
+    if do_cache:
+        dataframe.unpersist()
diff --git a/caikit/interfaces/ts/data_model/backends/pandas_backends.py b/caikit/interfaces/ts/data_model/backends/pandas_backends.py
new file mode 100644
index 000000000..dff5a42f3
--- /dev/null
+++ b/caikit/interfaces/ts/data_model/backends/pandas_backends.py
@@ -0,0 +1,516 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Core data model backends backed by pandas
+"""
+
+# Standard
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Tuple, Type, Union
+import json
+
+# Third Party
+from pandas import RangeIndex
+import numpy as np
+import pandas as pd
+
+# First Party
+import alog
+
+# Local
+from .....core.data_model import DataBase, ProducerId
+from .....core.exceptions import error_handler
+from ... import data_model as dm
+from ..
import time_types +from ..toolkit.optional_dependencies import HAVE_PYSPARK +from .base import ( + MultiTimeSeriesBackendBase, + StrictFieldBackendMixin, + TimeSeriesBackendBase, + UncachedBackendMixin, +) +from .spark_util import iteritems_workaround +from .util import pd_timestamp_to_seconds + +if TYPE_CHECKING: + # Local + from ..timeseries import TimeSeries + +log = alog.use_channel("PDBCK") +error = error_handler.get(log) + + +class PandasMultiTimeSeriesBackend(MultiTimeSeriesBackendBase): + def as_pandas(self) -> Tuple[pd.DataFrame, Iterable[str], str, Iterable[str]]: + return self._df, self._key_column, self._timestamp_column, self._value_columns + + def __init__( + self, + data_frame: pd.DataFrame, + key_column: Union[Iterable[str], str], + timestamp_column: Optional[str] = None, + value_columns: Optional[Iterable[str]] = None, + ids: Optional[Union[Iterable[int], Iterable[str]]] = None, + producer_id: Optional[Union[Tuple[str, str], ProducerId]] = None, + ): + error.type_check("", pd.DataFrame, data_frame=data_frame) + error.type_check( + "", + list, + str, + key_column=key_column, + ) + error.type_check( + "", str, int, type(None), timestamp_column=timestamp_column + ) + error.type_check_all( + "", + str, + int, + allow_none=True, + value_columns=value_columns, + ) + error.type_check_all( + "", + str, + allow_none=True, + ids=ids, + ) + error.type_check( + "", + tuple, + ProducerId, + allow_none=True, + producer_id=producer_id, + ) + + # Validate the column names + error.value_check( + "", + (timestamp_column is None or (timestamp_column in data_frame.columns)), + "Invalid timestamp column/index: {}", + timestamp_column, + ) + + self._df = data_frame + self._key_column = key_column + self._timestamp_column = timestamp_column + key_column_list = [key_column] if isinstance(key_column, str) else key_column + # pylint: disable=duplicate-code + self._value_columns = value_columns or [ + col + for col in data_frame.columns + if col != timestamp_column and col not in key_column_list + ] + self._ids = [] if ids is None else ids + self._producer_id = ( + producer_id + if isinstance(producer_id, ProducerId) + else (ProducerId(*producer_id) if producer_id is not None else None) + ) + self._timeseries = None + self._key_columns = ( + [self._key_column] + if isinstance(self._key_column, str) + else self._key_column + ) + + def get_attribute(self, data_model_class: Type["TimeSeries"], name: str) -> Any: + if name == "timeseries": + result = [] + + if len(self._key_columns) == 0: + backend = PandasTimeSeriesBackend( + self._df, + timestamp_column=self._timestamp_column, + value_columns=self._value_columns, + ) + result.append(dm.SingleTimeSeries(_backend=backend)) + else: + for k, k_df in self._df.groupby( + self._key_columns + if len(self._key_columns) > 1 + else self._key_columns[0] + ): + # if it is a single key string, we want to just wrap it in a list + if isinstance(k, (str, int)): + k = [k] + backend = PandasTimeSeriesBackend( + k_df, + timestamp_column=self._timestamp_column, + value_columns=self._value_columns, + ids=k, + ) + result.append(dm.SingleTimeSeries(_backend=backend)) + + return result + + if name == "id_labels": + return self._key_columns + + # If requesting producer_id or ids, just return the stored value + if name == "producer_id": + return self._producer_id + + raise ValueError(f"Provided an attribute name that does not exist - {name}") + + +class PandasTimeSeriesBackend(TimeSeriesBackendBase): + """The PandasTimeSeriesBackend is responsible for managing the standard + 
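The per-key split performed in `get_attribute` above is ordinary pandas `groupby`, plus the key normalization that wraps scalar keys in a list. A standalone sketch with illustrative data:

```python
import pandas as pd

df = pd.DataFrame(
    {"id": ["A", "A", "B"], "ts": [0, 1, 0], "value": [1.0, 2.0, 3.0]}
)

# One sub-frame per key; scalar keys are normalized to lists, as in the backend
for key, sub_df in df.groupby("id"):
    if isinstance(key, (str, int)):
        key = [key]
    print(key, len(sub_df))  # ['A'] 2, then ['B'] 1
```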
in-memory representation of a TimeSeries + """ + + def __init__( + self, + data_frame: pd.DataFrame, + timestamp_column: str = None, + value_columns: Optional[Iterable[str]] = None, + ids: Optional[Union[Iterable[int], Iterable[str]]] = None, + ): + """At init time, hold onto the data frame as well as the arguments that + tell where the time and values live + + Args: + data_frame: pd.DataFrame + The raw data frame + timestamp_column: Optional[str] + The name of the column holding the timestamps. If set to None, timestamps will be + assigned based on the rows index (default is None) + value_columns: Optional[Iterable[str]] + A sequence of names of columns to hold as values + ids: Optional[iterable[int]] + A sequence of numeric IDs associated with this TimeSeries + """ + # Validate the types and column names + error.type_check("", pd.DataFrame, data_frame=data_frame) + error.type_check( + "", str, type(None), timestamp_column=timestamp_column + ) + error.type_check_all( + "", + str, + allow_none=True, + value_columns=value_columns, + ) + error.type_check_all( + "", + str, + np.int_, + int, + allow_none=True, + ids=ids, + ) + + # Validate the column names + + error.value_check( + "", + (timestamp_column is None or (timestamp_column in data_frame.columns)), + "Invalid timestamp column/index: {}", + timestamp_column, + ) + value_columns = value_columns or [ + col for col in data_frame.columns if col != timestamp_column + ] + error.value_check( + "", + # TODO: Support lambdas! + all(value_col in data_frame.columns for value_col in value_columns), + "Invalid value columns: {}", + value_columns, + ) + + self._df = data_frame + self._timestamp_column = timestamp_column + self._value_columns = value_columns + self._ids = [] if ids is None else ids + + # pylint: disable=too-many-return-statements + def get_attribute( + self, + data_model_class: Type["dm.SingleTimeSeries"], + name: str, + external_df: pd.DataFrame = None, + ) -> Any: + """When fetching a data attribute from the timeseries, this aliases to + the appropriate set of backend wrappers for the various fields. 
+ """ + + # use the external definition of our pandas-like dataframe if + # requested + pandas_impl = external_df if external_df is not None else self._df + + if name == "timestamp_label": + return self._timestamp_column + + if name == "ids" and self._ids is not None and len(self._ids) != 0: + if isinstance(self._ids[0], (np.int_, int)): + val = data_model_class.IntIDSequence( + values=[ + id.item() if isinstance(id, np.int_) else id for id in self._ids + ] + ) + return DataBase.OneofFieldVal(val=val, which_oneof="id_int") + if isinstance(self._ids[0], str): + val = data_model_class.StringIDSequence(values=self._ids) + return DataBase.OneofFieldVal(val=val, which_oneof="id_str") + + # If requesting the value_labels, this is the value column names + if name == "value_labels": + return [str(val) for val in self._value_columns] + + # If requesting the "time_sequence" or one of the oneof fields, extract + # the timestamps from the dataframe + if name == "time_sequence": + if self._timestamp_column is None: + time_sequence = RangeIndex(start=0, stop=pandas_impl.shape[0], step=1) + else: + time_sequence = pandas_impl[self._timestamp_column] + + # If the sequence is periodic, use the PeriodicTimeSequence backend + is_periodic = isinstance(time_sequence.dtype, pd.PeriodDtype) or isinstance( + time_sequence, RangeIndex + ) + if is_periodic: + val = time_types.PeriodicTimeSequence.from_backend( + PandasPeriodicTimeSequenceBackend(time_sequence) + ) + return DataBase.OneofFieldVal(val=val, which_oneof="time_period") + # Otherwise, use the PointTimeSequence backend + val = time_types.PointTimeSequence.from_backend( + PandasPointTimeSequenceBackend(time_sequence) + ) + return DataBase.OneofFieldVal(val=val, which_oneof="time_points") + + # If requesting the value sequences, return the wrapped value columns + if name == "values": + return [ + time_types.ValueSequence.from_backend( + PandasValueSequenceBackend(pandas_impl, col_name) + ) + for col_name in self._value_columns + ] + + def as_pandas(self) -> Tuple[pd.DataFrame, str, Iterable[str]]: + """Return the underlying data frame""" + return self._df, self._timestamp_column, self._value_columns + + +class PandasValueSequenceBackend(UncachedBackendMixin, StrictFieldBackendMixin): + """Backend for ValueSequence backed by a set of columns in a Pandas + DataFrame + """ + + @staticmethod + def _serialize_any(any_val): + try: + json_str = json.dumps(any_val) + return json_str + except Exception as exc: + raise TypeError("could not serialize the given value") from exc + + # This dtype is what shows up for non-periodic date ranges + _TIMESTAMP_DTYPE = np.dtype("datetime64[ns]") + + # What types do we consider to be vector types + _DEFAULT_VECTOR_TYPES = [list, np.ndarray] + if HAVE_PYSPARK: + # pyspark.pandas.DataFrame objects can contain + # pyspark specific objects + + # Third Party + # pylint: disable=import-outside-toplevel + from pyspark.ml.linalg import Vector as SVector + + _DEFAULT_VECTOR_TYPES.append(SVector) + + def __init__(self, data_frame: pd.Series, col_name: str): + """Initialize with the data frame and the value column name""" + self._df = data_frame + self._col_name = col_name + # Determine which of the oneof types is valid for this sequence + self._dtype = self._df[self._col_name].dtype + self._converter = lambda x: x + if str(self._dtype).startswith( + str(self.__class__._TIMESTAMP_DTYPE)[:-1] + ) or isinstance(self._dtype, pd.PeriodDtype): + # what do we want to do here, are we just assuming it will convert forever? 
+ self._sequence_type = time_types.ValueSequence.TimePointSequence + self._valid_oneof = "val_timepoint" + # todo not sure why np.issubdtype is running into issue if this is run after + elif self._dtype == "string": + self._sequence_type = time_types.ValueSequence.StrValueSequence + self._valid_oneof = "val_str" + elif np.issubdtype(self._dtype, np.integer): + self._sequence_type = time_types.ValueSequence.IntValueSequence + self._valid_oneof = "val_int" + elif np.issubdtype(self._dtype, np.floating): + self._sequence_type = time_types.ValueSequence.FloatValueSequence + self._valid_oneof = "val_float" + # todo do we handle ndarrays in cells, if so we need to convert to list before going to json + # as ndarray is not serializable + # this is making the assumption that we have at least one value in the dataframe + elif str(self._dtype) == "object" and isinstance( + self._df[self._col_name].iloc[0], + tuple(PandasValueSequenceBackend._DEFAULT_VECTOR_TYPES), + ): + self._sequence_type = time_types.ValueSequence.VectorValueSequence + self._valid_oneof = "val_vector" + else: + self._sequence_type = time_types.ValueSequence.AnyValueSequence + self._valid_oneof = "val_any" + + def get_attribute( + self, + data_model_class: Type[time_types.ValueSequence], + name: str, + ) -> Any: + """Get the known attributes from the underlying DataFrame columns""" + + if name == "sequence": + name = self._valid_oneof + if name == self._valid_oneof and name in [ + "val_int", + "val_float", + "val_str", + "val_vector", + ]: + return self._sequence_type( + values=[ + self._converter(val) + for val in iteritems_workaround( + self._df[self._col_name], force_list=True + ) + ], + ) + if name == self._valid_oneof == "val_any": + return self._sequence_type( + values=[ + self._serialize_any(val) + for val in iteritems_workaround( + self._df[self._col_name], force_list=False + ) + ] + ) + + if name == self._valid_oneof == "val_timepoint": + return self._sequence_type( + values=[ + val.isoformat() if hasattr(val, "isoformat") else str(val) + for val in iteritems_workaround( + self._df[self._col_name], force_list=False + ) + ] + ) + + # Delegate to common parent logic + return super().get_attribute(data_model_class, name) + + +class PandasPeriodicTimeSequenceBackend(UncachedBackendMixin, StrictFieldBackendMixin): + """Backend for PeriodicTimeSequence backed by a Pandas Time Span""" + + def __init__(self, time_sequence): + """Initialize with a periodic time sequence""" + self._is_range_index = isinstance(time_sequence, RangeIndex) + if self._is_range_index: + self._start_time = time_sequence.start + self._period_length = time_sequence.step + else: + self._start_time = ( + None if time_sequence.empty else time_sequence.iloc[0].start_time + ) + self._period_length = time_sequence.dtype.freq.name + + def get_attribute( + self, + data_model_class: Type[time_types.PeriodicTimeSequence], + name: str, + ) -> Any: + """Get the known attributes from the backend data""" + if name == "start_time" and self._start_time is not None: + return time_types.TimePoint.from_backend( + PandasTimePointBackend(self._start_time) + ) + if name == "period_length": + if self._is_range_index: + return time_types.TimeDuration(dt_int=self._period_length) + + return time_types.TimeDuration(dt_str=self._period_length) + + # Delegate to common parent logic + return super().get_attribute(data_model_class, name) + + +class PandasPointTimeSequenceBackend( + UncachedBackendMixin, + StrictFieldBackendMixin, +): # TODO: Should we cache this one??? 
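The dtype dispatch in `PandasValueSequenceBackend.__init__` above reduces to a small classifier. The function below is a simplified, hypothetical mirror of that logic (it omits the pyspark vector types and the converter setup):

```python
import numpy as np
import pandas as pd


def classify_column(series: pd.Series) -> str:
    """Return the ValueSequence oneof name the backend would pick."""
    dtype = series.dtype
    if str(dtype).startswith("datetime64") or isinstance(dtype, pd.PeriodDtype):
        return "val_timepoint"
    if dtype == "string":
        return "val_str"
    if np.issubdtype(dtype, np.integer):
        return "val_int"
    if np.issubdtype(dtype, np.floating):
        return "val_float"
    if str(dtype) == "object" and isinstance(series.iloc[0], (list, np.ndarray)):
        return "val_vector"
    return "val_any"  # anything else gets JSON-serialized


print(classify_column(pd.Series([1, 2])))                 # val_int
print(classify_column(pd.Series(["a"], dtype="string")))  # val_str
print(classify_column(pd.Series([[1.0], [2.0]])))         # val_vector
print(classify_column(pd.Series([{"k": 1}])))             # val_any
```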
+ """Backend for PointTimeSequence backed by a Pandas Series""" + + def __init__(self, time_sequence: pd.Series): + """Initialize with a series based time sequence""" + self._time_sequence = time_sequence + + def get_attribute( + self, + data_model_class: Type[time_types.PointTimeSequence], + name: str, + ) -> Any: + """Get the known attributes from the backend data""" + if name == "points": + # TODO: a user may have ints/floats stored as objects in their dataframe, should we + # handle that or throw an exception + return [ + time_types.TimePoint.from_backend(PandasTimePointBackend(point_data)) + for point_data in iteritems_workaround( + self._time_sequence, force_list=True + ) + ] + + # Delegate to common parent logic + return super().get_attribute(data_model_class, name) + + +class PandasTimePointBackend(UncachedBackendMixin, StrictFieldBackendMixin): + """Backend for time point data held by Pandas""" + + def __init__(self, point_data: Any): + """Initialize with the raw pandas value""" + self._point_data = point_data + + def get_attribute( + self, + data_model_class: Type[time_types.TimePoint], + name: str, + ) -> Any: + """Get the appropriate fields based on the data type of the point""" + int_ok = name in ["time", "ts_int"] + float_ok = name in ["time", "ts_float"] + epoch_ok = name in ["time", "ts_epoch"] + + if epoch_ok and isinstance( + self._point_data, (pd.Timestamp, datetime, np.datetime64, pd.Period) + ): + return time_types.Seconds(seconds=pd_timestamp_to_seconds(self._point_data)) + dtype = getattr(self._point_data, "dtype", None) + if int_ok and ( + isinstance(self._point_data, int) or np.issubdtype(dtype, np.integer) + ): + return self._point_data + if float_ok and ( + isinstance(self._point_data, float) + or (dtype is not None and np.issubdtype(dtype, np.floating)) + ): + return self._point_data diff --git a/caikit/interfaces/ts/data_model/backends/spark_util.py b/caikit/interfaces/ts/data_model/backends/spark_util.py new file mode 100644 index 000000000..ab6dc96fa --- /dev/null +++ b/caikit/interfaces/ts/data_model/backends/spark_util.py @@ -0,0 +1,79 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Internal utilities for supporting spark backend implementations""" + +# Standard +from typing import Any, Iterable, List + +# Third Party +import pandas as pd + +# Local +from ..toolkit.optional_dependencies import HAVE_PYSPARK, pyspark + + +def iteritems_workaround(series: Any, force_list: bool = False) -> Iterable: + """pyspark.pandas.Series objects do not support + iteration. For native pandas.Series objects this + function will be a no-op. 
+
+    For pyspark.pandas.Series or other iterable objects
+    we try to_numpy() (unless force_list
+    is True) and, if that fails, we fall back to to_list()
+
+    """
+
+    # check that we can convert
+    if not hasattr(series, "to_list") and not hasattr(series, "to_numpy"):
+        raise NotImplementedError(
+            f"invalid type {type(series)} passed for parameter series"
+        )
+
+    if isinstance(series, pd.Series):
+        return series
+
+    # handle an edge case of pyspark.ml.linalg.DenseVector
+    if (
+        HAVE_PYSPARK
+        and isinstance(series, pyspark.pandas.series.Series)
+        and isinstance(series[0], pyspark.ml.linalg.Vector)
+    ):
+        return [x.toArray().tolist() for x in series.to_numpy()]
+
+    # note that we're forcing a list only if we're not
+    # a native pandas series
+    if force_list:
+        return series.to_list()
+
+    try:
+        return series.to_numpy()
+    except:  # noqa: E722
+        return series.to_list()
+
+
+def mock_pd_groupby(a_df_like, by: List[str], return_pandas_api=False):
+    """Roughly mocks the behavior of pandas groupby, but on a spark dataframe."""
+
+    distinct_keys = a_df_like.select(by).distinct().collect()
+    for dkey in distinct_keys:
+        adict = dkey.asDict()
+        filter_statement = ""
+        for k, v in adict.items():
+            filter_statement += f" {k} == '{v}' and"
+        if filter_statement.endswith("and"):
+            filter_statement = filter_statement[0:-3]
+        sub_df = a_df_like.filter(filter_statement)
+        value = tuple(adict.values())
+        value = value[0] if len(value) == 1 else value
+        yield value, sub_df.pandas_api() if return_pandas_api else sub_df
diff --git a/caikit/interfaces/ts/data_model/backends/util.py b/caikit/interfaces/ts/data_model/backends/util.py
new file mode 100644
index 000000000..627e0970c
--- /dev/null
+++ b/caikit/interfaces/ts/data_model/backends/util.py
@@ -0,0 +1,101 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Internal utilities for supporting backend implementations"""
+
+# Standard
+from datetime import datetime
+from typing import Union
+
+# Third Party
+import numpy as np
+import pandas as pd
+
+
+def timezoneoffset(adatetime: datetime) -> int:
+    """Returns the timezone offset (in seconds)
+    for a given datetime object relative to the local
+    system's time.
+
+    Args:
+        adatetime (datetime): a date of interest.
+
+    Returns:
+        int: offset in seconds (can be negative)
+    """
+    return (
+        adatetime.timestamp()
+        - datetime(
+            year=adatetime.year,
+            month=adatetime.month,
+            day=adatetime.day,
+            hour=adatetime.hour,
+            minute=adatetime.minute,
+            second=adatetime.second,
+            microsecond=adatetime.microsecond,
+        ).timestamp()
+    )
+
+
+def pd_timestamp_to_seconds(ts) -> float:
+    """Extract the seconds-since-epoch representation of the timestamp
+
+    NOTE: The pandas Timestamp.timestamp() function returns a different value
+    than Timestamp.to_pydatetime().timestamp()! Since we want this to
+    round-trip with python datetime, we want the latter. They both claim to
+    be POSIX, so something is missing leap-something!
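`mock_pd_groupby` yields `(key, sub-DataFrame)` pairs much like pandas `groupby`. A hedged usage sketch, assuming a local `SparkSession` and illustrative data:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
sdf = spark.createDataFrame(
    [("A", 0, 1.0), ("A", 1, 2.0), ("B", 0, 3.0)], ["id", "ts", "value"]
)

# Single-column keys arrive as scalars; multi-column keys arrive as tuples
for key, sub_df in mock_pd_groupby(sdf, by=["id"]):
    print(key, sub_df.count())  # A 2, then B 1
```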
+ """ + if isinstance(ts, pd.Period): + return ts.to_timestamp().timestamp() # no utc shift + if isinstance(ts, np.datetime64): + return ts.astype("datetime64[ns]").astype(float) / 1e9 + if isinstance(ts, datetime): + return ts.timestamp() + if isinstance(ts, (int, float, np.int32, np.int64, np.float32, np.float64)): + return float(ts) + raise ValueError(f"invalid type {type(ts)} for parameter ts.") + + +def strip_periodic( + input_df: pd.DataFrame, ts_col_name: Union[str, None] = None, create_copy=True +) -> pd.DataFrame: + """ + Removes **the first instance** of a periodic timestamp info + (because spark doesn't like these when constructing a pyspark.sql.DataFrame.) + If no periodic timestamp values can be found, input_df is returned as is. + This method is always a no-op if input_df is not a native pandas.DataFrame. + """ + + if not isinstance(input_df, pd.DataFrame): + return input_df + + # find location of period field + try: + index = ( + [type(x) for x in input_df.dtypes].index(pd.core.dtypes.dtypes.PeriodDtype) + if ts_col_name is None + else input_df.columns.to_list().index(ts_col_name) + ) + except ValueError: + index = -1 + + df = input_df + if index >= 0: + df = input_df if not create_copy else input_df.copy(deep=False) + # df.iloc[:, index] + df[df.columns[index]] = [ + x.to_timestamp() if hasattr(x, "to_timestamp") else x + for x in df.iloc[:, index] + ] + + return df diff --git a/caikit/interfaces/ts/data_model/package.py b/caikit/interfaces/ts/data_model/package.py new file mode 100644 index 000000000..3685c06eb --- /dev/null +++ b/caikit/interfaces/ts/data_model/package.py @@ -0,0 +1,19 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from caikit.core.data_model import CAIKIT_DATA_MODEL + +# Shared package name constant for all TS data model objects +TS_PACKAGE = f"{CAIKIT_DATA_MODEL}.timeseries" diff --git a/caikit/interfaces/ts/data_model/time_types.py b/caikit/interfaces/ts/data_model/time_types.py new file mode 100644 index 000000000..063204a7d --- /dev/null +++ b/caikit/interfaces/ts/data_model/time_types.py @@ -0,0 +1,226 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +The core data model objects for primitive time types +""" + +# Standard +from datetime import datetime, timedelta, timezone +from functools import lru_cache +from typing import List, Tuple, Union +import json + +# Third Party +import numpy as np + +# First Party +from py_to_proto.dataclass_to_proto import Annotated, FieldNumber, OneofField +import alog + +# Local +from ....core import DataObjectBase +from ....core.data_model import dataobject +from ....core.exceptions import error_handler +from .package import TS_PACKAGE + +log = alog.use_channel("TSDM") +error = error_handler.get(log) + + +@dataobject(package=TS_PACKAGE) +class Seconds(DataObjectBase): + """A nanosecond value that can be interpreted as either a datetime or a + timedelta + """ + + seconds: Annotated[float, FieldNumber(1)] + + def as_datetime(self) -> datetime: + """Return a python datetime object. + The returned object will have timezone.utc set as its timezone info.""" + return datetime.fromtimestamp(self.seconds, tz=timezone.utc) + + def as_timedelta(self) -> timedelta: + """Interpret these nanoseconds as a duration""" + return timedelta(seconds=self.seconds) + + @classmethod + def from_datetime(cls, time_point: datetime) -> "Seconds": + """Create a Seconds from a datetime""" + return cls(seconds=time_point.timestamp()) + + @classmethod + def from_timedelta(cls, time_delta: timedelta) -> "Seconds": + """Create a Seconds from a timedelta""" + return cls(seconds=time_delta.total_seconds()) + + +@dataobject(package=TS_PACKAGE) +class TimePoint(DataObjectBase): + """ + The core data model object for a TimePoint + """ + + time: Union[ + Annotated[int, OneofField("ts_int"), FieldNumber(1)], + Annotated[float, OneofField("ts_float"), FieldNumber(2)], + Annotated[Seconds, OneofField("ts_epoch"), FieldNumber(3)], + ] + + +@dataobject(package=TS_PACKAGE) +class TimeDuration(DataObjectBase): + """ + The core data model object for a TimeDuration + """ + + time: Union[ + Annotated[int, OneofField("dt_int"), FieldNumber(1)], + Annotated[float, OneofField("dt_float"), FieldNumber(2)], + Annotated[str, OneofField("dt_str"), FieldNumber(3)], + Annotated[Seconds, OneofField("dt_sec"), FieldNumber(4)], + ] + + +@dataobject(package=TS_PACKAGE) +class PeriodicTimeSequence(DataObjectBase): + """A PeriodicTimeSequence represents an indefinite time sequence where ticks + occur at a regular period + """ + + start_time: Annotated[TimePoint, FieldNumber(1)] + period_length: Annotated[TimeDuration, FieldNumber(2)] + + +@dataobject(package=TS_PACKAGE) +class PointTimeSequence(DataObjectBase): + """A PointTimeSequence represents a finite sequence of time points that may + or may not be evenly distributed in time + """ + + points: Annotated[List[TimePoint], FieldNumber(1)] + + +@dataobject(package=TS_PACKAGE) +class Vector(DataObjectBase): + """A vector represents a finite sequence of doubles""" + + data: Annotated[List[float], FieldNumber(1)] + + +@dataobject(package=TS_PACKAGE) +class ValueSequence(DataObjectBase): + """A ValueSequence is a finite list of contiguous values, each representing + the value of a given attribute for a specific observation within a + TimeSeries + """ + + @dataobject(package=TS_PACKAGE) + class IntValueSequence(DataObjectBase): + """Nested value sequence of integers""" + + values: Annotated[List[int], FieldNumber(1)] + + @dataobject(package=TS_PACKAGE) + class FloatValueSequence(DataObjectBase): + """Nested value sequence of floats""" + + values: Annotated[List[float], FieldNumber(1)] + + 
@dataobject(package=TS_PACKAGE) + class StrValueSequence(DataObjectBase): + """Nested value sequence of strings""" + + values: Annotated[List[str], FieldNumber(1)] + + @dataobject(package=TS_PACKAGE) + class VectorValueSequence(DataObjectBase): + """Nested value sequence of vectors""" + + values: Annotated[List[Vector], FieldNumber(1)] + + def __post_init__(self): + error.type_check("", list, values=self.values) + error.type_check_all( + "", list, np.ndarray, Vector, values=self.values + ) + + def _convert_np_to_list(self, v): + return v.tolist() + + def to_dict(self): + result = [] + for v in self.values: + v_in = self._convert_np_to_list(v) if isinstance(v, np.ndarray) else v + result.append({"data": v_in if isinstance(v_in, list) else v.data}) + return {"values": result} + + def fill_proto(self, proto): + subproto = proto.values + subproto.extend( + [ + v.to_proto() + if isinstance(v, Vector) + else Vector( + v if isinstance(v, list) else self._convert_np_to_list(v) + ).to_proto() + for v in self.values + ] + ) + return proto + + @classmethod + def from_proto(cls, proto): + return cls(**{"values": [list(v.data) for v in proto.values]}) + + # todo we can have a constuct for sequences that require serialization + @dataobject(package=TS_PACKAGE) + class TimePointSequence(DataObjectBase): + """Nested value sequence of TimePoints""" + + values: Annotated[List[str], FieldNumber(1)] + + # todo we can have a construct for sequences that require serialization + @dataobject(package=TS_PACKAGE) + class AnyValueSequence(DataObjectBase): + """Nested value sequence of Any objects""" + + values: Annotated[List[str], FieldNumber(1)] + + @classmethod + @lru_cache(maxsize=None) + def decode_values(cls, values: Tuple[str]): + """Cached class method to enable caching of decoded representations""" + return [json.loads(v) for v in values] + + def to_dict(self): + return {"values": self.__class__.decode_values(tuple(self.values))} + + def fill_proto(self, proto): + subproto = proto.values + subproto.extend(self.__class__.decode_values(tuple(self.values))) + return proto + + @classmethod + def from_proto(cls, proto): + return cls(**{"values": [json.dumps(v) for v in proto.values]}) + + sequence: Union[ + Annotated[IntValueSequence, OneofField("val_int"), FieldNumber(1)], + Annotated[FloatValueSequence, OneofField("val_float"), FieldNumber(2)], + Annotated[StrValueSequence, OneofField("val_str"), FieldNumber(3)], + Annotated[TimePointSequence, OneofField("val_timepoint"), FieldNumber(4)], + Annotated[AnyValueSequence, OneofField("val_any"), FieldNumber(5)], + Annotated[VectorValueSequence, OneofField("val_vector"), FieldNumber(6)], + ] diff --git a/caikit/interfaces/ts/data_model/timeseries.py b/caikit/interfaces/ts/data_model/timeseries.py new file mode 100644 index 000000000..ecfc240c6 --- /dev/null +++ b/caikit/interfaces/ts/data_model/timeseries.py @@ -0,0 +1,348 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Standard +from typing import Iterable, List, Optional, Tuple + +# Third Party +import numpy as np +import pandas as pd + +# First Party +import alog + +# Local +from ....core import DataObjectBase +from ....core.data_model import ProducerId, dataobject +from ....core.exceptions import error_handler +from ._single_timeseries import SingleTimeSeries +from .backends.base import MultiTimeSeriesBackendBase +from .backends.pandas_backends import PandasMultiTimeSeriesBackend +from .backends.util import strip_periodic +from .package import TS_PACKAGE +from .toolkit.optional_dependencies import HAVE_PYSPARK, pyspark +from .toolkit.sparkconf import sparkconf_local + +log = alog.use_channel("TSDM") +error = error_handler.get(log) + + +@dataobject(package=TS_PACKAGE) +class TimeSeries(DataObjectBase): + timeseries: List[SingleTimeSeries] + id_labels: List[str] + producer_id: ProducerId + + _DEFAULT_ID_COL = "_TS_RESERVED" + _DEFAULT_TS_COL = "timestamp" + + def __init__(self, *args, **kwargs): + """Constructing a TimeSeries will currently delegate + to either a pandas or spark dataframe backend depending + on whether a native pandas or spark dataframe are passed for + the first argument respectively. + """ + + if "timeseries" in kwargs: + self.timeseries = None + self.id_labels = None + self.producer_id = None + is_multi = True + for k, v in kwargs.items(): + if k == "timeseries" and not isinstance(v, list): + is_multi = False + setattr(self, k, [v]) + else: + setattr(self, k, v) + + # if id_labels was never set, that means we have a single timeseries + if not is_multi: + self.id_labels = [] + else: + error.value_check( + "", + len(args) != 0, + "must have at least the data argument", + args, + ) + data_arg = args[0] + + # This will be done if SingleTimeSeries + if kwargs.get("key_column") is None: + kwargs["key_column"] = [] + + if isinstance(data_arg, pd.DataFrame): + self._backend = PandasMultiTimeSeriesBackend(*args, **kwargs) + elif HAVE_PYSPARK and isinstance(data_arg, pyspark.sql.DataFrame): + # Local + # pylint: disable=import-outside-toplevel + from ..data_model.backends._spark_backends import ( + SparkMultiTimeSeriesBackend, + ) + + self._backend = SparkMultiTimeSeriesBackend(*args, **kwargs) + + def __len__(self) -> int: + """Return the length of the time series object. + + Returns: + int: Length + """ + backend = getattr(self, "_backend", None) + + if backend is None: + if self.timeseries: + return sum(len(ts) for ts in self.timeseries) + return 0 + + if HAVE_PYSPARK: + # Local + # pylint: disable=import-outside-toplevel + from ..data_model.backends._spark_backends import ( + SparkMultiTimeSeriesBackend, + ) + + if isinstance(backend, PandasMultiTimeSeriesBackend): + return len(backend._df) + if HAVE_PYSPARK and isinstance(self._backend, SparkMultiTimeSeriesBackend): + return backend._pyspark_df.count() + + error.log_raise( + "", + f"Unknown backend {type(backend)}", + ) # pragma: no cover + + def __eq__(self, other: "TimeSeries") -> bool: + """Equivalence operator for TimeSeries objects. + + Args: + other (TimeSeries): TimeSeries to test against. + + Returns: + bool: True if the TimeSeries are equivalent. 
+ """ + + # if number of mts is different, always unequal + if len(self.timeseries) != len(other.timeseries): + return False + + # empty mts is equal + if len(self.timeseries) == 0: + # ignoring edge cases of empty mts with different columns + # unclear if this is even possible + return True # pragma: no cover + + # degenerate case + if len(self.timeseries) == 1: + return self.timeseries[0] == other.timeseries[0] + + # create map between keys and time series + left_id_map = {tuple(ts.ids.values): ts for ts in self.timeseries} + right_id_map = {tuple(ts.ids.values): ts for ts in other.timeseries} + + # quickly check keys are identical + if set(left_id_map.keys()) != set(right_id_map.keys()): + return False + + return all(l_ts == right_id_map[l_key] for l_key, l_ts in left_id_map.items()) + + def _get_pd_df(self) -> Tuple[pd.DataFrame, Iterable[str], str, Iterable[str]]: + """Convert the data to a pandas DataFrame, efficiently if possible""" + + # If there is a backend that knows how to do the conversion, use that + backend = getattr(self, "_backend", None) + if backend is not None and isinstance(backend, MultiTimeSeriesBackendBase): + log.debug("Using backend pandas conversion") + return backend.as_pandas() + + error.value_check( + "", + self.timeseries is not None, + "Cannot create pandas data frame without any timeseries present", + ) + + error.value_check( + "", + self.id_labels is not None, + "Cannot create pandas data frame without any key labels present", + ) + + key_columns = self.id_labels + dfs = [] + value_columns = None + timestamp_column = None + for ts in self.timeseries: # pylint: disable=not-an-iterable + if value_columns is None: + value_columns = ts.value_labels + if ts.timestamp_label != "": + timestamp_column = ts.timestamp_label + df = ts._get_pd_df()[0] + + for i, key_col in enumerate(key_columns): + id_val = ts.ids.values[i] + df[key_col] = [id_val] * df.shape[0] + dfs.append(df) + ignore_index = True # timestamp_column != "" + result = pd.concat(dfs, ignore_index=ignore_index) + self._backend = PandasMultiTimeSeriesBackend( + result, + key_column=key_columns, + timestamp_column=timestamp_column, + value_columns=value_columns, + ) + + return ( + result, + key_columns, + timestamp_column, + value_columns, + ) + + def as_pandas( + self, include_timestamps: Optional[bool] = None, is_multi: Optional[bool] = None + ) -> "pd.DataFrame": + """Get the view of this timeseries as a pandas DataFrame + + Args: + include_timestamps (bool, optional): Control the addition or removal of + timestamps. True will include timestamps, generating if needed, while False will + remove timestamps. Use None to returned what is available, leaving unchanged. + Defaults to None. + + is_multi (bool, optional): Controls how id_labels are handled in the output. If + the id_labels are specified in the data model, they are always returned. If there + are no id_labels specified, setting is_multi to True will add a new column with + generated id labels (0), while False or None will not add any id_labels. 
+ + Returns: + pd.DataFrame: The view of the data as a pandas DataFrame + """ + # if as_pandas is_multi is True, and timeseries is_multi is False => add a RESERVED id + # column with constant value + # if as_pandas is_multi is True, and timeseries is_multi is True => do nothing just return + # as is + # if as_pandas is_multi is False, and timeseries is_multi is True => remove the id columns + # if as_pandas is_multi is False, and timeseries is_multi is False => do nothing just + # return as is + # if as_pandas is_multi is None => do nothing just return as is + if len(self.id_labels) == 0: + # pylint: disable=unsubscriptable-object + df = self.timeseries[0].as_pandas(include_timestamps=include_timestamps) + + # add a RESERVED id column with constant value + if is_multi is not None and is_multi: + df = df.copy(deep=True) + df[self.__class__._DEFAULT_ID_COL] = np.zeros(len(df), dtype=np.int32) + return df + + backend_df = self._get_pd_df()[0] + timestamp_column = self._backend._timestamp_column + + # if we want to include timestamps, but it is not already in the dataframe, we need to + # add it + if include_timestamps and timestamp_column is None: + backend_df = backend_df.copy() # avoid mutating original + ts_column = self.__class__._DEFAULT_TS_COL + backend_df[ts_column] = [0] * len(backend_df) + backend_df[ts_column] = backend_df.groupby( + self._backend._key_column, sort=False + )[ts_column].transform(lambda x: list(range(len(x)))) + return backend_df + # if we do not want timestamps, but we already have them in the dataframe, we need to + # return a view without timestamps + if ( + include_timestamps is not None and not include_timestamps + ) and timestamp_column is not None: + return backend_df.loc[:, backend_df.columns != timestamp_column] + + return backend_df + + def as_spark( + self, include_timestamps: Optional[bool] = None, is_multi: Optional[bool] = None + ) -> "pyspark.sql.DataFrame": + """Get the view of this timeseries as a spark DataFrame + + Args: + include_timestamps (bool, optional): Control the addition or removal of + timestamps. True will include timestamps, generating if needed, while False will + remove timestamps. Use None to returned what is available, leaving unchanged. + Defaults to None. + + is_multi (bool, optional): Controls how id_labels are handled in the output. If + the id_labels are specified in the data model, they are always returned. If there + are no id_labels specified, setting is_multi to True will add a new column with + generated id labels (0), while False or None will not add any id_labels. + + Returns: + pyspark.sql.DataFrame: The view of the data as a spark DataFrame + """ + if not HAVE_PYSPARK: + raise NotImplementedError("pyspark must be available to use this method.") + + # todo: is this right??? 
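The `include_timestamps`/`is_multi` combinations spelled out above, sketched for the single-series case (no `id_labels`); the data, column names, and ordering here are illustrative:

```python
import pandas as pd

ts = TimeSeries(pd.DataFrame({"value": [1.0, 2.0, 3.0]}))

ts.as_pandas().columns                         # just 'value'
ts.as_pandas(include_timestamps=True).columns  # 'value' plus a generated 'timestamp'
ts.as_pandas(is_multi=True).columns            # 'value' plus a constant '_TS_RESERVED' id
```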
+ if len(self.id_labels) == 0: + # pylint: disable=unsubscriptable-object + df = self.timeseries[0].as_spark(include_timestamps=include_timestamps) + # add a RESERVED id column with constant value + if is_multi is not None and is_multi: + df = df.pandas_api() + df = df.copy(deep=True) + df[self.__class__._DEFAULT_ID_COL] = np.zeros( + len(df), dtype=np.int32 + ).tolist() + df = df.to_spark() + return df + + # Third Party + # pylint: disable=import-outside-toplevel + from pyspark.sql import SparkSession + + # Local + # pylint: disable=import-outside-toplevel + from ..data_model.backends._spark_backends import SparkMultiTimeSeriesBackend + + # If there is a backend that knows how to do the conversion, use that + backend = getattr(self, "_backend", None) + if backend is not None and isinstance(backend, SparkMultiTimeSeriesBackend): + answer = backend._pyspark_df + timestamp_column = backend._timestamp_column + if include_timestamps and timestamp_column is None: + + def append_timestamp_column(aspark_df, key_cols, timestamp_name): + sql = ( + f"row_number() OVER (PARTITION BY {','.join(key_cols)} " + f"ORDER BY {','.join(key_cols)}) -1 as {timestamp_name}" + ) + return aspark_df.selectExpr("*", sql) + + answer = append_timestamp_column( + answer, key_cols=self.id_labels, timestamp_name="timestamp" + ) + elif ( + include_timestamps is not None + and not include_timestamps + and timestamp_column is not None + ): + answer = answer.drop(timestamp_column) + return answer + + pdf = strip_periodic( + self.as_pandas(include_timestamps=include_timestamps), + create_copy=True, + ) + return ( + SparkSession.builder.config(conf=sparkconf_local()) + .getOrCreate() + .createDataFrame(pdf) + ) diff --git a/caikit/interfaces/ts/data_model/timeseries_evaluation.py b/caikit/interfaces/ts/data_model/timeseries_evaluation.py new file mode 100644 index 000000000..71769922a --- /dev/null +++ b/caikit/interfaces/ts/data_model/timeseries_evaluation.py @@ -0,0 +1,219 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The core data model object for a TimeSeries Evaluator. +""" +# Standard +from typing import List, Union + +# Third Party +import pandas as pd + +# First Party +from py_to_proto.dataclass_to_proto import ( # Annotated imported from here for compatibility + Annotated, + FieldNumber, + OneofField, +) +import alog + +# Local +from ....core import DataObjectBase +from ....core.data_model import ProducerId, dataobject +from ....core.exceptions import error_handler +from .package import TS_PACKAGE + +log = alog.use_channel("TSEDM") +error = error_handler.get(log) + +## TimeSeries Evaluator ################################################################## + + +@dataobject(package=TS_PACKAGE) +class Id(DataObjectBase): + """A single instance of Id + Representation of ids that can be either text or index. 
Customized
+    this way to be able to work with repeated fields.
+    """
+
+    value: Union[
+        Annotated[str, OneofField("text"), FieldNumber(1)],
+        Annotated[int, OneofField("index"), FieldNumber(2)],
+    ]
+
+
+@dataobject(package=TS_PACKAGE)
+class EvaluationRecord(DataObjectBase):
+    """A single EvaluationRecord for EvaluationResult
+    Representation of EvaluationRecord for each row in the dataframe
+    EvaluationRecord{id_values=["A", "B"], metric_values=[0.234, 0.568, 0.417], offset="overall"}
+    """
+
+    id_values: Annotated[List[Id], FieldNumber(1)]
+    metric_values: Annotated[List[float], FieldNumber(2)]
+    offset: Annotated[Id, FieldNumber(3)]
+
+    def __init__(self, id_values=None, metric_values=None, offset=None):
+        """Construct a new EvaluationRecord instance
+
+        EvaluationRecord
+
+        Args:
+            id_values: list(Id)
+                List of Id values for the record
+            metric_values: list(float)
+                List of float values containing metric results for the record
+            offset: (optional) Id
+                offset associated with the record
+        """
+
+        error.type_check_all(
+            "", str, int, Id, allow_none=True, id_values=id_values
+        )
+        error.type_check_all("", float, metric_values=metric_values)
+        error.type_check("", str, int, Id, allow_none=True, offset=offset)
+
+        super().__init__()
+
+        self.id_values = (
+            []
+            if id_values is None
+            else [
+                Id(id_value) if not isinstance(id_value, Id) else id_value
+                for id_value in id_values
+            ]
+        )
+
+        self.metric_values = metric_values
+
+        self.offset = (
+            None
+            if offset is None
+            else Id(offset)
+            if not isinstance(offset, Id)
+            else offset
+        )
+
+
+@dataobject(package=TS_PACKAGE)
+class EvaluationResult(DataObjectBase):
+    """EvaluationResult containing the evaluation results
+    Stores the rows of the input dataframe as a list of records, along with the
+    string lists needed to keep track of the id and metric columns
+    """
+
+    records: Annotated[List[EvaluationRecord], FieldNumber(1)]
+    id_cols: Annotated[List[str], FieldNumber(2)]
+    metric_cols: Annotated[List[str], FieldNumber(3)]
+    offset_col: Annotated[str, FieldNumber(4)]
+    producer_id: Annotated[ProducerId, FieldNumber(5)]
+
+    def __init__(
+        self,
+        records=None,
+        id_cols=None,
+        metric_cols=None,
+        offset_col=None,
+        df=None,
+        producer_id=None,
+    ):
+        """Construct a new EvaluationResult instance
+
+        EvaluationResult
+
+        Args:
+            records: list(EvaluationRecord)
+                List of Evaluation Record instances
+            id_cols: list(string)
+                List of strings containing id column names (Optional)
+            metric_cols: list(string)
+                List of strings containing metric value column names
+            offset_col: string
+                Name of the offset column in the dataframe, if it exists (Optional)
+            df: pandas dataframe
+                Initial input dataframe from which the records are built (Optional)
+            producer_id: ProducerId | None
+                The module that produced this evaluation result.
+ """ + + error.type_check_all("", str, allow_none=True, id_cols=id_cols) + error.type_check_all("", str, metric_cols=metric_cols) + error.type_check("", str, allow_none=True, offset_col=offset_col) + error.type_check( + "", + tuple, + ProducerId, + allow_none=True, + producer_id=producer_id, + ) + + super().__init__() + + self.id_cols = [] if id_cols is None else id_cols + self.metric_cols = metric_cols + self.offset_col = offset_col + self.producer_id = producer_id + + if df is not None: + if self.offset_col is not None: + error.value_check( + "", + self.offset_col in df.columns, + f"Specified '{self.offset_col}' offset column not in dataframe", + ) + + self.records = [ + EvaluationRecord( + id_values=( + None + if len(self.id_cols) == 0 + else df.loc[i, self.id_cols].values.tolist() + ), + metric_values=df.loc[i, self.metric_cols].values.tolist(), + offset=( + None if self.offset_col is None else df.loc[i, self.offset_col] + ), + ) + for i in range(len(df)) + ] + else: + error.type_check_all("", EvaluationRecord, records=records) + self.records = records + + def as_pandas(self) -> "pd.DataFrame": + """Generate and return a pandas DataFrame""" + + records = [] + + has_offset = False + for record in self.records: + id_values = [] + metric_values = [] + offset = None + + id_values = [v.value for v in record.id_values] + metric_values = record.metric_values + if record.offset: + offset = record.offset.value + has_offset = True + + records.append(id_values + metric_values + [offset]) + + df = pd.DataFrame( + records, columns=self.id_cols + self.metric_cols + [self.offset_col] + ) + if not has_offset: + df.drop([self.offset_col], axis=1, inplace=True) + + return df diff --git a/caikit/interfaces/ts/data_model/toolkit/__init__.py b/caikit/interfaces/ts/data_model/toolkit/__init__.py new file mode 100644 index 000000000..2068258bf --- /dev/null +++ b/caikit/interfaces/ts/data_model/toolkit/__init__.py @@ -0,0 +1,13 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/caikit/interfaces/ts/data_model/toolkit/optional_dependencies.py b/caikit/interfaces/ts/data_model/toolkit/optional_dependencies.py new file mode 100644 index 000000000..bf830e494 --- /dev/null +++ b/caikit/interfaces/ts/data_model/toolkit/optional_dependencies.py @@ -0,0 +1,95 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module encapsulates the core optional dependencies using lazy imports. 
+""" + +# Standard +from types import ModuleType +from typing import Optional +import importlib + +# First Party +import alog + +log = alog.use_channel("OPTDEP") + + +## Implementation ############################################################## + + +class LazyModule(ModuleType): + """A LazyModule is a module subclass that wraps another module but imports + it lazily and then aliases __getattr__ to the lazily imported module. + """ + + def __init__(self, name: str, package: Optional[str] = None): + """Hang onto the import args to use lazily""" + self.__name = name + self.__package = package + self.__wrapped_module = None + + def __getattr__(self, name: str) -> any: + """When asked for an attribute, make sure the wrapped module is imported + and then delegate + """ + if self.__wrapped_module is None: + log.debug1("Triggering lazy import for %s.%s", self.__package, self.__name) + self.__wrapped_module = importlib.import_module( + self.__name, + self.__package, + ) + return getattr(self.__wrapped_module, name) + + +def have_module(name: str, package: Optional[str] = None) -> bool: + """This method can be used to check whether a given optional dependency is + available and should primarily be used for assertions when coding + defensively. + + NOTE: Nested modules WILL force the import of parent modules + + TODO: Move this to import_tracker + + Args: + name: str + The name of the module + package: Optional[str] + The qualifying package for the module under investigation + + Returns: + have_module: bool + True if the module can be imported, False otherwise + """ + spec = importlib.util.find_spec(name, package) + return ( + # No spec found under standard import + spec is not None + and spec.loader is not None + and + # Spec not found and delegated to import_tracker lazy failures + spec.loader.__module__.split(".")[0] != "import_tracker" + ) + + +## Public ###################################################################### + +# The core optional dependencies +pd = LazyModule("pandas") +pyspark = LazyModule("pyspark") + +# Import-time checks for the presence of optional dependencies +HAVE_NUMPY = have_module("numpy") +HAVE_PANDAS = have_module("pandas") +HAVE_PYSPARK = have_module("pyspark") diff --git a/caikit/interfaces/ts/data_model/toolkit/sparkconf.py b/caikit/interfaces/ts/data_model/toolkit/sparkconf.py new file mode 100644 index 000000000..bebc78dab --- /dev/null +++ b/caikit/interfaces/ts/data_model/toolkit/sparkconf.py @@ -0,0 +1,191 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Defines function(s) for obtaining a spark configurations.""" + +# Standard +from typing import Union +import socket + +# Local +from .optional_dependencies import HAVE_PYSPARK + +WE_HAVE_PYSPARK = HAVE_PYSPARK + +if WE_HAVE_PYSPARK: + # Third Party + from pyspark import SparkConf + + +def sparkconf_local( + master: str = "local[2]", + executor_memory: str = "2g", + driver_memory: str = "2g", + app_name: str = "unnamed", + **kwargs, +): + """Returns a SparkConf object configured for spark-local operation + + Args: + executor_memory (str, optional): Exectuor memory. Defaults to "2g". + driver_memory (str, optional): Driver memory. Defaults to "2g". + app_name (str, optional): Spark application name. Defaults to "unnamed". + kwargs: passthru key,value arguments that will be added to the spark configuration + + Returns: + SparkConf: a spark configuration object. + """ + + if not WE_HAVE_PYSPARK: + return {} + + if master.find("local[") != 0: + raise ValueError( + "master for local session must be in form 'local[N]' where N is either an integer or *" + ) + + return sparkconf_k8s( + master=master, + executor_memory=executor_memory, + driver_memory=driver_memory, + app_name=app_name, + namespace="foo", + driver_image="foo", + executor_image="foo", + **kwargs, + ) + + +# pylint: disable=line-too-long +def sparkconf_k8s( + app_name: str, + namespace: str, + executor_image: str, + driver_image: str, + master: str = "k8s://https://kubernetes.default.svc:443", + num_executors: str = "2", + executor_memory: str = "1g", + executor_cores: str = "2", + driver_memory: str = "1g", + driver_cores: str = "2", + pvc_mount_path: Union[str, None] = None, + pvc_claim_name: Union[str, None] = None, + python_path: Union[str, None] = None, + k8s_service_account: Union[str, None] = None, + **kwargs, +): + """Return a spark configuraion object for use on a kubernetes cluster. For more information on + what some of these parameters are for see + https://spark.apache.org/docs/latest/running-on-kubernetes.html + + NOTE: if you are simply running a local spark job, we advise you use the sparkconf_local method + instead as it has fewer parameters and more defaults to get you going more quickly. + + Args: + app_name (str): The application name (useful for for keeping track of jobs on a multiuser + cluster) + namespace (str): k8s namespace in which this job will run (e.g., "default") + executor_image (str): The container image to use for spark executors. + driver_image (str): The spark driver image to use (tpyically the same as exectuor image) + master (_type_, optional): The master specificication. Defaults to + "k8s://https://kubernetes.default.svc:443". + num_executors (str, optional): The number of executors to run. Defaults to "2". + executor_memory (str, optional): The maximum memory allocated to each executor (use g or M + notation). Defaults to "1g". + executor_cores (str, optional): The maximum number of cores per executor. Defaults to "2". + driver_memory (str, optional): The maxumum memory allocated to the driver. Defaults to + "1g". + driver_cores (str, optional): The maximum number of cores allocated to the driver. Defaults + to "2". + pvc_mount_path (str | None, optional): The PVC mount path for exectuors and driver to mount + (this usually has to be rwX). Defaults to None. + pvc_claim_name (str | None, optional): The PVC claim name assocated with the PVC mount. + Defaults to None. + python_path (str | None, optional): The python path to use in python jobs in executor and + driver python processes. 
Defaults to None. + k8s_service_account (str | None, optional): The k8s service account to use. Defaults to + None. + kwargs: passthru key,value arguments that will be added to the spark configuration + + Returns: + SparkConf: A spark configuration that has been defined in a way that makes it compatible + with time series use cases and intended for use with a k8s cluster. + """ + + if not WE_HAVE_PYSPARK: + return {} + + conf: SparkConf = ( + SparkConf().setAppName(f"{app_name}.caikit.{namespace}").setMaster(master) + ) + + # pushing config out of global configuration file + conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") + + # executor/driver spec + conf.set("spark.driver.memory", driver_memory) + conf.set("spark.executor.memory", executor_memory) + conf.set("spark.executor.cores", executor_cores) + conf.set("spark.driver.cores", driver_cores) + conf.set("spark.executor.instances", num_executors) + conf.set("spark.sql.session.timeZone", "UTC") + + # kubernetes specific + if "K8S" in master.upper(): + if python_path: + conf.setExecutorEnv("PYTHONPATH", python_path) + conf.set("spark.kubernetes.namespace", namespace) + conf.set("spark.kubernetes.executor.container.image", executor_image) + conf.set( + "spark.kubernetes.driver.container.image", + driver_image if driver_image else executor_image, + ) + conf.set("spark.kubernetes.driver.annotation.sidecar.istio.io/inject", "false") + conf.set( + "spark.kubernetes.executor.annotation.sidecar.istio.io/inject", "false" + ) + # networking minutia + conf.set("spark.driver.host", socket.gethostbyname(socket.gethostname())) + conf.set("spark.driver.port", "37371") + conf.set("spark.blockManager.port", "6060") + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + # conf.set("spark.kubernetes.authenticate.driver.serviceAccountName", "spark") + + if pvc_mount_path: + conf.set( + f"spark.kubernetes.executor.volumes.persistentVolumeClaim.{pvc_claim_name}.mount.path", + pvc_mount_path, + ) + conf.set( + f"spark.kubernetes.executor.volumes.persistentVolumeClaim.{pvc_claim_name}.mount.readOnly", + "false", + ) + + if pvc_claim_name: + conf.set( + f"spark.kubernetes.executor.volumes.persistentVolumeClaim.{pvc_claim_name}.options.claimName", + pvc_claim_name, + ) + + if k8s_service_account: + conf.set( + "spark.kubernetes.authenticate.driver.serviceAccountName", + k8s_service_account, + ) + + conf.set("spark.kubernetes.container.image.pullPolicy", "Always") + + for param, val in kwargs.items(): + conf.set(param, val) + + return conf diff --git a/caikit/interfaces/ts/tasks.py b/caikit/interfaces/ts/tasks.py new file mode 100644 index 000000000..f99630b95 --- /dev/null +++ b/caikit/interfaces/ts/tasks.py @@ -0,0 +1,54 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
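A short usage sketch for the sparkconf helpers above (hedged: it assumes pyspark is installed and that the module lands at the import path shown in this patch; the extra `spark.ui.enabled` key is only an example of the **kwargs passthru):

    # Hypothetical usage of sparkconf_local; the import path mirrors this patch.
    from pyspark.sql import SparkSession

    from caikit.interfaces.ts.data_model.toolkit.sparkconf import sparkconf_local

    # Extra **kwargs entries pass straight through to conf.set(...)
    conf = sparkconf_local(app_name="demo", **{"spark.ui.enabled": "false"})
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    # sparkconf_local delegates to sparkconf_k8s with namespace="foo", so the
    # app name comes back as "demo.caikit.foo"
    print(spark.sparkContext.appName)

Note that when pyspark is absent both helpers return a plain empty dict rather than raising, so callers are expected to gate on HAVE_PYSPARK themselves.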
+ +""" +All timeseries tasks used by modules +""" + +# Local +from .data_model.timeseries import TimeSeries +from .data_model.timeseries_evaluation import EvaluationResult +from caikit.core import TaskBase, task + + +@task( + required_parameters={"X": TimeSeries}, + output_type=TimeSeries, +) +class AnomalyDetectionTask(TaskBase): + """Task for all anomaly detection modules""" + + +@task( + required_parameters={"targets": TimeSeries, "predictions": TimeSeries}, + output_type=EvaluationResult, +) +class EvaluationTask(TaskBase): + """Task for all evaluation modules""" + + +@task( + required_parameters={"X": TimeSeries}, + output_type=TimeSeries, +) +class ForecastingTask(TaskBase): + """Task for all forecasting modules""" + + +@task( + required_parameters={"X": TimeSeries}, + output_type=TimeSeries, +) +class TransformersTask(TaskBase): + """Task for all transformer modules""" diff --git a/caikit/runtime/__main__.py b/caikit/runtime/__main__.py index 20b288536..2483e3749 100644 --- a/caikit/runtime/__main__.py +++ b/caikit/runtime/__main__.py @@ -1,6 +1,16 @@ -# Standard -import signal - +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # First Party import alog @@ -14,25 +24,6 @@ def main(): _grpc_server = None _http_server = None - def interrupt(signal_, _stack_frame): - log.info( - "", - "Caikit Runtime received interrupt signal %s, shutting down", - signal_, - ) - if _grpc_server: - _grpc_server.stop() - if _http_server: - _http_server.stop() - - # NOTE: signal function can only be called from main thread of the main - # interpreter. If this function is called from a thread (like in tests) - # then signal handler cannot be used. Thus, we will only have real - # termination_handler when this is called from the __main__. - - signal.signal(signal.SIGINT, interrupt) - signal.signal(signal.SIGTERM, interrupt) - ##################### # Start the servers ##################### diff --git a/caikit/runtime/dump_services.py b/caikit/runtime/dump_services.py index bec7f9f45..0bf376323 100644 --- a/caikit/runtime/dump_services.py +++ b/caikit/runtime/dump_services.py @@ -13,6 +13,7 @@ # limitations under the License. 
 # Standard
+import argparse
 import json
 import os
 import sys
@@ -29,21 +30,34 @@
 log = alog.use_channel("RUNTIME-DUMP-SVC")
 
 
-def dump_grpc_services(output_dir: str):
+def dump_grpc_services(output_dir: str, write_modules_file: bool):
     """Utility for rendering the all generated interfaces to proto files"""
-    inf_svc = ServicePackageFactory.get_service_package(
-        ServicePackageFactory.ServiceType.INFERENCE, write_modules_file=True
-    )
-    train_svc = ServicePackageFactory.get_service_package(
-        ServicePackageFactory.ServiceType.TRAINING,
-    )
-    train_mgt_svc = ServicePackageFactory.get_service_package(
-        ServicePackageFactory.ServiceType.TRAINING_MANAGEMENT,
+    inf_enabled = get_config().runtime.service_generation.enable_inference
+    train_enabled = get_config().runtime.service_generation.enable_training
+
+    if inf_enabled:
+        inf_svc = ServicePackageFactory.get_service_package(
+            ServicePackageFactory.ServiceType.INFERENCE,
+            write_modules_file=write_modules_file,
+        )
+    if train_enabled:
+        train_svc = ServicePackageFactory.get_service_package(
+            ServicePackageFactory.ServiceType.TRAINING,
+        )
+        train_mgt_svc = ServicePackageFactory.get_service_package(
+            ServicePackageFactory.ServiceType.TRAINING_MANAGEMENT,
+        )
+    info_svc = ServicePackageFactory.get_service_package(
+        ServicePackageFactory.ServiceType.INFO,
     )
+
     render_dataobject_protos(output_dir)
-    inf_svc.service.write_proto_file(output_dir)
-    train_svc.service.write_proto_file(output_dir)
-    train_mgt_svc.service.write_proto_file(output_dir)
+    if inf_enabled:
+        inf_svc.service.write_proto_file(output_dir)
+    if train_enabled:
+        train_svc.service.write_proto_file(output_dir)
+        train_mgt_svc.service.write_proto_file(output_dir)
+    info_svc.service.write_proto_file(output_dir)
 
 
 def dump_http_services(output_dir: str):
@@ -86,13 +100,40 @@
     handle.write(json.dumps(response.json(), indent=2))
 
 
-if __name__ == "__main__":
-    assert len(sys.argv) == 2, f"Usage: {sys.argv[0]} "
-    out_dir = sys.argv[1]
+def main():
+    parser = argparse.ArgumentParser(
+        description="Dump grpc and http services for inference and training"
+    )
+
+    # Add an argument for the output_dir
+    parser.add_argument(
+        "output_dir",
+        type=str,
+        help="Path to the output directory for service(s)' proto files",
+    )
+
+    # Add an argument for write_modules_json
+    parser.add_argument(
+        "-j",
+        "--write-modules-json",
+        default=False,
+        action="store_true",
+        help="Whether the modules.json (of supported modules) should be output",
+    )
+
+    args = parser.parse_args()
+
+    out_dir = args.output_dir
+    write_modules_json = args.write_modules_json
+
     # Set up logging so users can set LOG_LEVEL etc
     caikit.core.toolkit.logging.configure()
 
     if get_config().runtime.grpc.enabled:
-        dump_grpc_services(out_dir)
+        dump_grpc_services(out_dir, write_modules_json)
     if get_config().runtime.http.enabled:
         dump_http_services(out_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/caikit/runtime/grpc_server.py b/caikit/runtime/grpc_server.py
index b7b38af9b..d061eb020 100644
--- a/caikit/runtime/grpc_server.py
+++ b/caikit/runtime/grpc_server.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
- # Standard from concurrent import futures from typing import Optional, Union @@ -42,6 +41,7 @@ from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory from caikit.runtime.servicers.global_predict_servicer import GlobalPredictServicer from caikit.runtime.servicers.global_train_servicer import GlobalTrainServicer +from caikit.runtime.servicers.info_servicer import InfoServicer from caikit.runtime.servicers.model_runtime_servicer import ModelRuntimeServicerImpl from caikit.runtime.servicers.model_train_servicer import ModelTrainServicerImpl from caikit.runtime.servicers.training_management_servicer import ( @@ -68,9 +68,7 @@ def __init__( # Initialize basic server self.server = grpc.server( - futures.ThreadPoolExecutor( - max_workers=self.config.runtime.grpc.server_thread_pool_size - ), + thread_pool=self.thread_pool, interceptors=(PROMETHEUS_METRICS_INTERCEPTOR,), options=(self.config.runtime.grpc.options or {}).items(), ) @@ -86,7 +84,7 @@ def __init__( if self.enable_inference: log.info("", "Enabling gRPC inference service") self._global_predict_servicer = GlobalPredictServicer( - self.inference_service + self.inference_service, interrupter=self.interrupter ) self.server = CaikitRuntimeServerWrapper( server=self.server, @@ -102,6 +100,7 @@ def __init__( # And intercept a training service, if we have one if self.enable_training and self.training_service: + log.info("", "Enabling gRPC training service") global_train_servicer = GlobalTrainServicer(self.training_service) self.server = CaikitRuntimeServerWrapper( server=self.server, @@ -134,12 +133,22 @@ def __init__( # Add model runtime servicer to the gRPC server model_runtime_pb2_grpc.add_ModelRuntimeServicer_to_server( - ModelRuntimeServicerImpl(), self.server + ModelRuntimeServicerImpl(interrupter=self.interrupter), self.server ) service_names.append( model_runtime_pb2.DESCRIPTOR.services_by_name["ModelRuntime"].full_name ) + # Add runtime info servicer to the gRPC server + runtime_info_service: ServicePackage = ( + ServicePackageFactory.get_service_package( + ServicePackageFactory.ServiceType.INFO, + ) + ) + service_names.append(runtime_info_service.descriptor.full_name) + + runtime_info_service.registration_function(InfoServicer(), self.server) + # Add gRPC default health servicer. # We use the non-blocking implementation to avoid thread starvation. health_servicer = health.HealthServicer( @@ -153,7 +162,7 @@ def __init__( # Listen on a unix socket as well for model mesh. if self.config.runtime.grpc.unix_socket_path and os.path.exists( - self.config.runtime.grpc.unix_socket_path + os.path.dirname(self.config.runtime.grpc.unix_socket_path) ): try: self.server.add_insecure_port( @@ -184,11 +193,20 @@ def __init__( ) if self.tls_config.client.cert: log.info("", "Running with mutual TLS") + # Combine the client cert with the server's own cert so that + # health probes can use the server's key/cert instead of needing + # one signed by a potentially-external CA. + root_certificates = b"\n".join( + [ + bytes(self._load_secret(self.tls_config.client.cert), "utf-8"), + tls_server_pair[1], + ] + ) # Client will verify the server using server cert and the server # will verify the client using client cert. 
server_credentials = grpc.ssl_server_credentials( [tls_server_pair], - root_certificates=self._load_secret(self.tls_config.client.cert), + root_certificates=root_certificates, require_client_auth=True, ) else: @@ -205,14 +223,18 @@ def start(self, blocking: bool = True): Args: blocking (boolean): Whether to block until shutdown """ + # Boot the thread interrupter + if self.interrupter: + self.interrupter.start() + # Start the server. This is non-blocking, so we need to wait after self.server.start() log.info( "", - "Caikit Runtime is serving on port: %s with thread pool size: %s", + "Caikit Runtime is serving grpc on port: %s with thread pool size: %s", self.port, - self.config.runtime.grpc.server_thread_pool_size, + self.thread_pool._max_workers, ) if blocking: @@ -225,16 +247,21 @@ def stop(self, grace_period_seconds: Union[float, int] = None): grace_period_seconds (Union[float, int]): Grace period for service shutdown. Defaults to application config """ + log.info("Shutting down gRPC server") if grace_period_seconds is None: grace_period_seconds = ( self.config.runtime.grpc.server_shutdown_grace_period_seconds ) + log.debug4("Stopping grpc server with %s grace seconds", grace_period_seconds) self.server.stop(grace_period_seconds) # Ensure we flush out any remaining billing metrics and stop metering if self.config.runtime.metering.enabled and self._global_predict_servicer: self._global_predict_servicer.stop_metering() # Shut down the model manager's model polling if enabled self._shut_down_model_manager() + # Shut down the thread interrupter + if self.interrupter: + self.interrupter.stop() def render_protos(self, proto_out_dir: str) -> None: """Renders all the necessary protos for this service into a directory @@ -262,7 +289,7 @@ def _load_secret(secret: str) -> str: """If the secret points to a file, return the contents (plaintext reads). Else return the string""" if os.path.exists(secret): - with open(secret, "r", encoding="utf-8") as secret_file: + with open(secret, encoding="utf-8") as secret_file: return secret_file.read() return secret @@ -277,7 +304,6 @@ def __exit__(self, type_, value, traceback): def main(blocking: bool = True): server = RuntimeGRPCServer() - server._intercept_interrupt_signal() server.start(blocking) diff --git a/caikit/runtime/http_server/__init__.py b/caikit/runtime/http_server/__init__.py index 30d530c38..7f07b87fd 100644 --- a/caikit/runtime/http_server/__init__.py +++ b/caikit/runtime/http_server/__init__.py @@ -13,5 +13,10 @@ # limitations under the License. # Local -from .http_server import HEALTH_ENDPOINT, RuntimeHTTPServer +from .http_server import ( + HEALTH_ENDPOINT, + MODELS_INFO_ENDPOINT, + RUNTIME_INFO_ENDPOINT, + RuntimeHTTPServer, +) from .pydantic_wrapper import dataobject_to_pydantic, pydantic_to_dataobject diff --git a/caikit/runtime/http_server/http_server.py b/caikit/runtime/http_server/http_server.py index fb5bd1de8..b008d9158 100644 --- a/caikit/runtime/http_server/http_server.py +++ b/caikit/runtime/http_server/http_server.py @@ -17,24 +17,27 @@ API based on the task definitions available at boot. 
""" # Standard -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import partial -from typing import Any, Dict, Iterable, Optional, Type, Union, get_args +from typing import Any, Dict, Iterable, List, Optional, Type, Union, get_args import asyncio +import inspect import io import json import os -import re +import signal import ssl import tempfile import threading import time +import traceback +import uuid # Third Party -from fastapi import FastAPI, HTTPException, Request, Response, status +from fastapi import FastAPI, HTTPException, Query, Request, Response, status from fastapi.encoders import jsonable_encoder -from fastapi.exceptions import ResponseValidationError +from fastapi.exceptions import RequestValidationError, ResponseValidationError from fastapi.responses import JSONResponse, PlainTextResponse from grpc import StatusCode from sse_starlette import EventSourceResponse, ServerSentEvent @@ -42,7 +45,10 @@ import uvicorn # First Party -from py_to_proto.dataclass_to_proto import get_origin # Imported here for 3.8 compat +from py_to_proto.dataclass_to_proto import ( # Imported here for 3.8 compat + Annotated, + get_origin, +) import aconfig import alog @@ -52,11 +58,27 @@ pydantic_from_request, pydantic_to_dataobject, ) +from .request_aborter import HttpRequestAborter from .utils import convert_json_schema_to_multipart, flatten_json_schema from caikit.config import get_config from caikit.core.data_model import DataBase from caikit.core.data_model.dataobject import make_dataobject +from caikit.core.exceptions import error_handler +from caikit.core.exceptions.caikit_core_exception import ( + CaikitCoreException, + CaikitCoreStatusCode, +) from caikit.core.toolkit.sync_to_async import async_wrap_iter +from caikit.runtime.names import ( + HEALTH_ENDPOINT, + MODEL_ID, + MODELS_INFO_ENDPOINT, + OPTIONAL_INPUTS_KEY, + REQUIRED_INPUTS_KEY, + RUNTIME_INFO_ENDPOINT, + StreamEventTypes, + get_http_route_name, +) from caikit.runtime.server_base import RuntimeServerBase from caikit.runtime.service_factory import ServicePackage from caikit.runtime.service_generation.rpcs import ( @@ -66,17 +88,19 @@ ) from caikit.runtime.servicers.global_predict_servicer import GlobalPredictServicer from caikit.runtime.servicers.global_train_servicer import GlobalTrainServicer +from caikit.runtime.servicers.info_servicer import InfoServicer from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException ## Globals ##################################################################### log = alog.use_channel("SERVR-HTTP") +error = error_handler.get(log) -# Mapping from GRPC codes to their corresponding HTTP codes -# pylint: disable=line-too-long -# CITE: https://chromium.googlesource.com/external/github.com/grpc/grpc/+/refs/tags/v1.21.4-pre1/doc/statuscodes.md -GRPC_CODE_TO_HTTP = { +STATUS_CODE_TO_HTTP = { + # Mapping from GRPC codes to their corresponding HTTP codes + # pylint: disable=line-too-long + # CITE: https://chromium.googlesource.com/external/github.com/grpc/grpc/+/refs/tags/v1.21.4-pre1/doc/statuscodes.md StatusCode.OK: 200, StatusCode.INVALID_ARGUMENT: 400, StatusCode.FAILED_PRECONDITION: 400, @@ -93,18 +117,17 @@ StatusCode.UNIMPLEMENTED: 501, StatusCode.UNAVAILABLE: 501, StatusCode.DEADLINE_EXCEEDED: 504, + # Mapping from CaikitCore StatusCodes codes to their corresponding HTTP codes + CaikitCoreStatusCode.INVALID_ARGUMENT: 400, + CaikitCoreStatusCode.UNAUTHORIZED: 401, + 
CaikitCoreStatusCode.FORBIDDEN: 403, + CaikitCoreStatusCode.NOT_FOUND: 404, + CaikitCoreStatusCode.CONNECTION_ERROR: 500, + CaikitCoreStatusCode.UNKNOWN: 500, + CaikitCoreStatusCode.FATAL: 500, } -# These keys are used to define the logical sections of the request and response -# data structures. -REQUIRED_INPUTS_KEY = "inputs" -OPTIONAL_INPUTS_KEY = "parameters" -MODEL_ID = "model_id" - -# Endpoint to use for health checks -HEALTH_ENDPOINT = "/health" - # Small dataclass for consolidating TLS files @dataclass class _TlsFiles: @@ -128,6 +151,26 @@ def __init__(self, tls_config_override: Optional[aconfig.Config] = None): self.app = FastAPI() + # Request validation + @self.app.exception_handler(RequestValidationError) + async def request_validation_exception_handler( + _, exc: RequestValidationError + ) -> Response: + err_code = status.HTTP_422_UNPROCESSABLE_ENTITY + error_content = { + "details": exc.errors()[0]["msg"] + if len(exc.errors()) > 0 and "msg" in exc.errors()[0] + else exc.errors(), + "additional_info": exc.errors(), + "code": err_code, + "id": uuid.uuid4().hex, + } + log.error("", error_content, exc_info=True) + return JSONResponse( + content=jsonable_encoder(error_content), + status_code=err_code, + ) + # Response validation @self.app.exception_handler(ResponseValidationError) async def validation_exception_handler(_, exc: ResponseValidationError): @@ -142,11 +185,14 @@ async def validation_exception_handler(_, exc: ResponseValidationError): # Placeholders for global servicers self.global_predict_servicer = None self.global_train_servicer = None + self.info_servicer = InfoServicer() # Set up inference if enabled if self.enable_inference: log.info("", "Enabling HTTP inference service") - self.global_predict_servicer = GlobalPredictServicer(self.inference_service) + self.global_predict_servicer = GlobalPredictServicer( + self.inference_service, interrupter=self.interrupter + ) self._bind_routes(self.inference_service) # Set up training if enabled @@ -160,6 +206,15 @@ async def validation_exception_handler(_, exc: ResponseValidationError): self._health_check ) + # Add runtime info endpoints + self.app.get(RUNTIME_INFO_ENDPOINT, response_class=JSONResponse)( + self.info_servicer.get_version_dict + ) + + self.app.get(MODELS_INFO_ENDPOINT, response_class=JSONResponse)( + self._model_info + ) + # Parse TLS configuration # If any of the TLS values are not files, we assume that they're inline # content. 
The python SslContext only takes files to load, so we use a
@@ -178,35 +233,46 @@ async def validation_exception_handler(_, exc: ResponseValidationError):
             log.info("", "Running INSECURE")
 
         # Start the server with a timeout_graceful_shutdown
-        # if not set in config, this is None and unvicorn accepts None or number of seconds
+        # if not set in config, this is None and uvicorn accepts None or a
+        # number of seconds
         unvicorn_timeout_graceful_shutdown = (
             get_config().runtime.http.server_shutdown_grace_period_seconds
         )
+        server_config = get_config().runtime.http.server_config
+        overlapping_tls_config = set(tls_kwargs).intersection(server_config)
+        error.value_check(
+            "",
+            not overlapping_tls_config,
+            "Found overlapping config keys between TLS and server_config: %s",
+            overlapping_tls_config,
+        )
+        config_kwargs = {
+            "host": "0.0.0.0",
+            "port": self.port,
+            "log_level": None,
+            "log_config": None,
+            "timeout_graceful_shutdown": unvicorn_timeout_graceful_shutdown,
+        }
+        overlapping_kwarg_config = set(config_kwargs).intersection(server_config)
+        error.value_check(
+            "",
+            not overlapping_kwarg_config,
+            "Found caikit-managed uvicorn config in server_config: %s",
+            overlapping_kwarg_config,
+        )
         config = uvicorn.Config(
             self.app,
-            host="0.0.0.0",
-            port=self.port,
-            log_level=None,
-            log_config=None,
-            timeout_graceful_shutdown=unvicorn_timeout_graceful_shutdown,
+            **config_kwargs,
             **tls_kwargs,
+            **server_config,
         )
 
         # Make sure the config loads TLS files here so they can safely be
         # deleted if they're ephemeral
         config.load()
 
-        # Start the server with the loaded config
+        # Build the server with the loaded config
         self.server = uvicorn.Server(config=config)
 
-        # Patch the exit handler to call this server's stop
-        original_handler = self.server.handle_exit
-
-        def shutdown_wrapper(*args, **kwargs):
-            original_handler(*args, **kwargs)
-            self.stop()
-
-        self.server.handle_exit = shutdown_wrapper
-
         # Placeholder for thread when running without blocking
         self._uvicorn_server_thread = None
@@ -221,10 +287,23 @@ def start(self, blocking: bool = True):
         Args:
             blocking (boolean): Whether to block until shutdown
         """
+        log.info(
+            "",
+            "Caikit Runtime is serving http on port: %s with thread pool size: %s",
+            self.port,
+            self.thread_pool._max_workers,
+        )
+
+        if self.interrupter:
+            self.interrupter.start()
+
+        # Patch the exit handler to retain correct signal handling behavior
+        self._patch_exit_handler()
+
         if blocking:
             self.server.run()
         else:
-            self.run_in_thread()
+            self._run_in_thread()
 
     def stop(self):
         """Stop the server, with an optional grace period.
@@ -233,11 +312,14 @@ def stop(self):
             grace_period_seconds (Union[float, int]): Grace period for service shutdown.
                 Defaults to application config
         """
-        self.server.should_exit = True
+        log.info("Shutting down http server")
+
         if (
             self._uvicorn_server_thread is not None
             and self._uvicorn_server_thread.is_alive()
         ):
+            # This is required to notify the server in the thread to exit
+            self.server.should_exit = True
             self._uvicorn_server_thread.join()
 
         # Ensure we flush out any remaining billing metrics and stop metering
@@ -247,16 +329,20 @@ def stop(self):
         # Shut down the model manager's model polling if enabled
         self._shut_down_model_manager()
 
-    def run_in_thread(self):
+        if self.interrupter:
+            self.interrupter.stop()
+
+    ##########
+    ## Impl ##
+    ##########
+
+    def _run_in_thread(self):
         self._uvicorn_server_thread = threading.Thread(target=self.server.run)
         self._uvicorn_server_thread.start()
         while not self.server.started:
             time.sleep(1e-3)
         log.info("HTTP Server is running in thread")
 
-    ##########
-    ## Impl ##
-    ##########
     def _bind_routes(self, service: ServicePackage):
         """Bind all rpcs as routes to the given app"""
         for rpc in service.caikit_rpcs.values():
@@ -265,7 +351,8 @@ def _bind_routes(self, service: ServicePackage):
             if hasattr(rpc, "input_streaming") and rpc.input_streaming:
                 # Skipping the binding of this route since we don't have support
                 log.info(
-                    "No support for input streaming on REST Server yet! Skipping this rpc %s with input type %s",
+                    "No support for input streaming on REST Server yet! "
+                    "Skipping this rpc %s with input type %s",
                     rpc_info["name"],
                     rpc_info["input_type"],
                 )
@@ -277,7 +364,7 @@
         elif isinstance(rpc, ModuleClassTrainRPC):
             self._train_add_unary_input_unary_output_handler(rpc)
 
-    def _get_model_id(self, request: Type[pydantic.BaseModel]) -> Dict[str, Any]:
+    def _get_model_id(self, request: Type[pydantic.BaseModel]) -> str:
         """Get the model id from the payload"""
         request_kwargs = dict(request)
         model_id = request_kwargs.get(MODEL_ID, None)
@@ -330,7 +417,7 @@ def _train_add_unary_input_unary_output_handler(self, rpc: CaikitRPCBase):
         pydantic_response = dataobject_to_pydantic(response_data_object)
 
         @self.app.post(
-            self._get_route(rpc),
+            get_http_route_name(rpc.name),
             responses=self._get_response_openapi(
                 response_data_object, pydantic_response
             ),
@@ -342,41 +429,44 @@ async def _handler(context: Request) -> Response:
             log.debug("In unary handler for %s", rpc.name)
             loop = asyncio.get_running_loop()
 
-            # build request DM object
-            request = await pydantic_from_request(pydantic_request, context)
-            http_request_dm_object = pydantic_to_dataobject(request)
-
             try:
+                # build request DM object
+                request = await pydantic_from_request(pydantic_request, context)
+                http_request_dm_object = pydantic_to_dataobject(request)
+
                 call = partial(
                     self.global_train_servicer.run_training_job,
                     request=http_request_dm_object.to_proto(),
                     module=rpc.clz,
-                    training_output_dir=None,  # pass None so that GTS picks up the config one  # TODO: double-check?
+                    training_output_dir=None,  # pass None so that GTS picks up the config one  # TODO: double-check?
# noqa: E501 # context=context, - wait=True, + wait=False, ) result = await loop.run_in_executor(None, call) if response_data_object.supports_file_operations: return self._format_file_response(result) return Response(content=result.to_json(), media_type="application/json") + except RequestValidationError as err: + raise err except HTTPException as err: raise err - except CaikitRuntimeException as err: - error_code = GRPC_CODE_TO_HTTP.get(err.status_code, 500) + except (CaikitCoreException, CaikitRuntimeException) as err: + error_code = STATUS_CODE_TO_HTTP.get(err.status_code, 500) error_content = { "details": err.message, "code": error_code, "id": err.id, } + log.error("", error_content, exc_info=True) except Exception as err: # pylint: disable=broad-exception-caught error_code = 500 error_content = { "details": f"Unhandled exception: {str(err)}", "code": error_code, - "id": None, + "id": uuid.uuid4().hex, } - log.error("", err, exc_info=True) + log.error("", error_content, exc_info=True) return Response( content=json.dumps(error_content), status_code=error_code ) # pylint: disable=used-before-assignment @@ -388,7 +478,7 @@ def _add_unary_input_unary_output_handler(self, rpc: TaskPredictRPC): pydantic_response = dataobject_to_pydantic(response_data_object) @self.app.post( - self._get_route(rpc), + get_http_route_name(rpc.name), responses=self._get_response_openapi( response_data_object, pydantic_response ), @@ -399,11 +489,10 @@ def _add_unary_input_unary_output_handler(self, rpc: TaskPredictRPC): async def _handler( context: Request, ) -> Response: - - request = await pydantic_from_request(pydantic_request, context) - request_params = self._get_request_params(rpc, request) - try: + request = await pydantic_from_request(pydantic_request, context) + request_params = self._get_request_params(rpc, request) + model_id = self._get_model_id(request) log.debug4( "Sending request %s to model id %s", request_params, model_id @@ -417,22 +506,25 @@ async def _handler( log.debug4( "Sending request %s to model id %s", request_params, model_id ) - model = self.global_predict_servicer._model_manager.retrieve_model( - model_id - ) - # TODO: use `async_wrap_*`? - call = partial( - self.global_predict_servicer.predict_model, - model_id=model_id, - request_name=rpc.request.name, - inference_func_name=model.get_inference_signature( - output_streaming=False, input_streaming=False, task=rpc.task - ).method_name, - **request_params, + aborter_context = ( + HttpRequestAborter(context) if self.interrupter else nullcontext() ) - result = await loop.run_in_executor(None, call) - log.debug4("Response from model %s is %s", model_id, result) + + with aborter_context as aborter: + # TODO: use `async_wrap_*`? 
+ call = partial( + self.global_predict_servicer.predict_model, + model_id=model_id, + request_name=rpc.request.name, + input_streaming=False, + output_streaming=False, + task=rpc.task, + aborter=aborter, + **request_params, + ) + result = await loop.run_in_executor(self.thread_pool, call) + log.debug4("Response from model %s is %s", model_id, result) if response_data_object.supports_file_operations: return self._format_file_response(result) @@ -441,21 +533,24 @@ async def _handler( except HTTPException as err: raise err - except CaikitRuntimeException as err: - error_code = GRPC_CODE_TO_HTTP.get(err.status_code, 500) + except RequestValidationError as err: + raise err + except (CaikitCoreException, CaikitRuntimeException) as err: + error_code = STATUS_CODE_TO_HTTP.get(err.status_code, 500) error_content = { "details": err.message, "code": error_code, "id": err.id, } + log.error("", error_content, exc_info=True) except Exception as err: # pylint: disable=broad-exception-caught error_code = 500 error_content = { "details": f"Unhandled exception: {str(err)}", "code": error_code, - "id": None, + "id": uuid.uuid4().hex, } - log.error("", err, exc_info=True) + log.error("", error_content, exc_info=True) return Response( content=json.dumps(error_content), status_code=error_code ) # pylint: disable=used-before-assignment @@ -466,7 +561,7 @@ def _add_unary_input_stream_output_handler(self, rpc: CaikitRPCBase): # pylint: disable=unused-argument @self.app.post( - self._get_route(rpc), + get_http_route_name(rpc.name), response_model=pydantic_response, openapi_extra=self._get_request_openapi(pydantic_request), ) @@ -482,65 +577,73 @@ async def _generator() -> pydantic_response: log.debug4( "Sending request %s to model id %s", request_params, model_id ) - model = self.global_predict_servicer._model_manager.retrieve_model( - model_id + + aborter_context = ( + HttpRequestAborter(context) + if self.interrupter + else nullcontext() ) - log.debug("In stream generator for %s", rpc.name) - async for result in async_wrap_iter( - self.global_predict_servicer.predict_model( - model_id=model_id, - request_name=rpc.request.name, - inference_func_name=model.get_inference_signature( - output_streaming=True, input_streaming=False - ).method_name, - **request_params, - ) - ): - yield result + with aborter_context as aborter: + log.debug("In stream generator for %s", rpc.name) + async for result in async_wrap_iter( + self.global_predict_servicer.predict_model( + model_id=model_id, + request_name=rpc.request.name, + input_streaming=False, + output_streaming=True, + task=rpc.task, + aborter=aborter, + **request_params, + ), + pool=self.thread_pool, + ): + yield ServerSentEvent( + data=result.to_json(), + event=StreamEventTypes.MESSAGE.value, + ) + return except HTTPException as err: raise err - except CaikitRuntimeException as err: - error_code = GRPC_CODE_TO_HTTP.get(err.status_code, 500) + except RequestValidationError as err: + raise err + except (TypeError, ValueError) as err: + log_dict = { + "log_code": "", + "message": repr(err), + "stack_trace": traceback.format_exc(), + } + log.warning(log_dict) + error_code = 400 + error_content = { + "details": repr(err), + "code": error_code, + } + except (CaikitCoreException, CaikitRuntimeException) as err: + error_code = STATUS_CODE_TO_HTTP.get(err.status_code, 500) error_content = { "details": err.message, "code": error_code, "id": err.id, } + log.error("", error_content, exc_info=True) except Exception as err: # pylint: disable=broad-exception-caught error_code = 500 
error_content = { "details": f"Unhandled exception: {str(err)}", "code": error_code, - "id": None, + "id": uuid.uuid4().hex, } - log.error("", err, exc_info=True) + log.error("", error_content, exc_info=True) # If an error occurs, yield an error response and terminate - yield ServerSentEvent(data=json.dumps(error_content)) + yield ServerSentEvent( + data=json.dumps(error_content), event=StreamEventTypes.ERROR.value + ) return EventSourceResponse(_generator()) - def _get_route(self, rpc: CaikitRPCBase) -> str: - """Get the REST route for this rpc""" - if rpc.name.endswith("Predict"): - task_name = re.sub( - r"(? Type[DataBase]: """Get the dataobject request for the given rpc""" is_inference_rpc = hasattr(rpc, "task") @@ -635,19 +738,26 @@ def _format_file_response(dm_class: Type[DataBase]) -> Response: @staticmethod def _get_request_openapi( - pydantic_model: Union[pydantic.BaseModel, Type[pydantic.BaseModel]] + pydantic_model: Union[pydantic.BaseModel, Type, Type[pydantic.BaseModel]] ): """Helper to generate the openapi schema for a given request""" - raw_schema = pydantic_model.model_json_schema() - parsed_schema = flatten_json_schema(raw_schema) + # Get the json schema from the pydantic model or TypeAdapter + if inspect.isclass(pydantic_model) and issubclass( + pydantic_model, pydantic.BaseModel + ): + raw_schema = pydantic_model.model_json_schema() + else: + raw_schema = pydantic.TypeAdapter(pydantic_model).json_schema() + + parsed_schema = flatten_json_schema(raw_schema) multipart_schema = convert_json_schema_to_multipart(parsed_schema) return { "requestBody": { "content": { - "application/json": {"schema": parsed_schema}, "multipart/form-data": {"schema": multipart_schema}, + "application/json": {"schema": parsed_schema}, }, "required": True, } @@ -655,7 +765,7 @@ def _get_request_openapi( @staticmethod def _get_response_openapi( - dm_class: Type[DataBase], pydantic_model: Type[pydantic.BaseModel] + dm_class: Type[DataBase], pydantic_model: Union[Type, Type[pydantic.BaseModel]] ): """Helper to generate the openapi schema for a given response""" @@ -664,15 +774,46 @@ def _get_response_openapi( "application/octet-stream": {"type": "string", "format": "binary"} } else: - response_schema = { - "application/json": flatten_json_schema( - pydantic_model.model_json_schema() - ) - } + # Get the json schema from the pydantic model or TypeAdapter + if inspect.isclass(pydantic_model) and issubclass( + pydantic_model, pydantic.BaseModel + ): + json_schema = pydantic_model.model_json_schema() + else: + json_schema = pydantic.TypeAdapter(pydantic_model).json_schema() + + response_schema = {"application/json": flatten_json_schema(json_schema)} output = {200: {"content": response_schema}} return output + def _model_info( + self, model_ids: Annotated[List[str], Query(default_factory=list)] + ) -> Dict[str, Any]: + """Create wrapper for get_models_info so model_ids can be marked as a query parameter""" + try: + return self.info_servicer.get_models_info_dict(model_ids) + except HTTPException as err: + raise err + except CaikitRuntimeException as err: + error_code = STATUS_CODE_TO_HTTP.get(err.status_code, 500) + error_content = { + "details": err.message, + "code": error_code, + "id": err.id, + } + log.error("", error_content, exc_info=True) + return error_content + except Exception as err: # pylint: disable=broad-exception-caught + error_code = 500 + error_content = { + "details": f"Unhandled exception: {str(err)}", + "code": error_code, + "id": uuid.uuid4().hex, + } + log.error("", error_content, 
exc_info=True)
+            return error_content
+
     @staticmethod
     def _health_check() -> str:
         log.debug4("Server healthy")
@@ -717,15 +858,39 @@ def _tls_files(self) -> _TlsFiles:
         except OSError as err:
             log.error(
                 "",
-                "Cannot create temporary TLS files. Either pass config as file paths or run with write permissions.",
+                (
+                    "Cannot create temporary TLS files. "
+                    "Either pass config as file paths or run with write permissions."
+                ),
                 exc_info=True,
             )
             raise ValueError() from err
 
+    def _patch_exit_handler(self):
+        """
+        🌶️🌶️🌶️ Here there are dragons! 🌶️🌶️🌶️
+        uvicorn will explicitly set the interrupt handler to `server.handle_exit` when
+        `server.run()` is called. That will override any other signal handlers that we
+        may have tried to set.
+
+        To work around this, we:
+        1. Register `server.handle_exit` as a SIGINT/SIGTERM signal handler ourselves, so that it
+        is invoked on interrupt and terminate
+        2. Set `server.handle_exit` to the existing SIGINT signal handler, so that when the uvicorn
+        server explicitly overrides the signal handler for SIGINT and SIGTERM to this, it has no
+        effect.
+
+        Since uvicorn overrides SIGINT and SIGTERM with a single common handler, any special
+        handlers added for SIGTERM but not SIGINT will not be invoked.
+        """
+        original_exit_handler = self.server.handle_exit
+        self._add_signal_handler(signal.SIGINT, original_exit_handler)
+        self._add_signal_handler(signal.SIGTERM, original_exit_handler)
+        self.server.handle_exit = signal.getsignal(signal.SIGINT)
+
 
 def main(blocking: bool = True):
     server = RuntimeHTTPServer()
-    server._intercept_interrupt_signal()
     server.start(blocking)
diff --git a/caikit/runtime/http_server/pydantic_wrapper.py b/caikit/runtime/http_server/pydantic_wrapper.py
index dd6a5cc87..0542cbc80 100644
--- a/caikit/runtime/http_server/pydantic_wrapper.py
+++ b/caikit/runtime/http_server/pydantic_wrapper.py
@@ -223,8 +223,7 @@ def _parse_form_data_to_pydantic(
     # Parse each form_data key into a python dict which is then
     # converted to a pydantic model via .model_validate()
     raw_model_obj = {}
-    for key in form_data.keys():
-
+    for key in form_data:
         # Get the list of objects that has the key
         # field name
         raw_objects = form_data.getlist(key)
@@ -284,7 +283,7 @@ def _parse_form_data_to_pydantic(
                 try:
                     raw_objects[n] = json.loads(sub_obj)
                 except TypeError:
-                    raise HTTPException(  # pylint: disable=raise-missing-from
+                    raise HTTPException(  # noqa: B904 # pylint: disable=raise-missing-from
                         status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                         detail=f"Unable to update object at key '{key};"
                         "; expected value to be string",
diff --git a/caikit/runtime/http_server/request_aborter.py b/caikit/runtime/http_server/request_aborter.py
new file mode 100644
index 000000000..c6cdc50dc
--- /dev/null
+++ b/caikit/runtime/http_server/request_aborter.py
@@ -0,0 +1,119 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
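The handler-swap described in `_patch_exit_handler` above can be illustrated standalone (a hedged sketch: FakeServer stands in for uvicorn.Server, and the print is illustrative):

    # Standalone sketch of the exit-handler swap used by _patch_exit_handler.
    import signal

    class FakeServer:
        def handle_exit(self, sig, frame):
            print("shutting down")

    server = FakeServer()
    # 1. Register the real exit handler for both signals ourselves
    signal.signal(signal.SIGINT, server.handle_exit)
    signal.signal(signal.SIGTERM, server.handle_exit)
    # 2. Point server.handle_exit at the handler that is already installed, so
    #    uvicorn's later signal.signal(..., server.handle_exit) re-installs the
    #    same handler and is effectively a no-op
    server.handle_exit = signal.getsignal(signal.SIGINT)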
+ +""" +This module helps us know when a HTTP client call is cancelled, and we need to stop or undo work +""" +# Standard +from typing import Optional +import asyncio +import threading + +# Third Party +from fastapi import Request + +# First Party +import alog + +# Local +from caikit.runtime.work_management.abortable_context import ( + AbortableContext, + ActionAborter, +) + +log = alog.use_channel("REQUEST-ABORTER") + + +class HttpRequestAborter(ActionAborter): + """ + In order to actually interrupt threads doing the work, abortable contexts can be registered + with an instance of this class in order to receive notification on request disconnection. + This allows work to be terminated when a client time's out or stops listening. + + IFF the client request has been terminated, `must_abort` will return True. + """ + + def __init__( + self, + context: Request, + loop: Optional[asyncio.AbstractEventLoop] = None, + poll_time: Optional[float] = 0.25, + ): + """Initialize a Aborter and start the watch loop + + Args: + context: starlette.Request + The HTTP Request to be watched + loop: Optional[asyncio.AbstractEventLoop] + The asyncio loop to run tasks on. If not provided use the running loop + poll_time: Optional[int] + The time between disconnect checks + """ + + self.context = context + self.event_loop = loop or asyncio.get_running_loop() + self.poll_time = poll_time + + # Set initial + self.is_terminated = threading.Event() + self.abortable_context = None + + # Start request aborter task. Hold onto a reference of the task to ensure + # it isn't garbage collected + log.debug("", "Watching for request disconnect") + self.task = self.event_loop.create_task(self.watch_for_disconnect()) + + def __enter__(self): + """Helper function to enable context manager support""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Automatically abort aborter when exiting contextmanager""" + self.abort() + + async def watch_for_disconnect(self): + """Wait for a context to be disconnected""" + + while True: + # Short circuit incase thread terminated externally + if self.is_terminated.is_set(): + log.debug3( + "", "RequestAborter has already been terminated" + ) + return + + is_disconnected = await self.context.is_disconnected() + if is_disconnected: + log.debug("", "Client disconnected, terminating action") + self.is_terminated.set() + if self.abortable_context: + self.abortable_context.abort() + return + log.debug4("", "Client still connected, sleeping aborter") + await asyncio.sleep(self.poll_time) + + def abort(self): + """Helper function to stop aborter before the request has terminated""" + self.is_terminated.set() + + def must_abort(self): + return self.is_terminated.is_set() + + def set_context(self, context: AbortableContext): + self.abortable_context = context + if self.must_abort(): + self.abortable_context.abort() + + def unset_context(self): + self.abortable_context = None diff --git a/caikit/runtime/interceptors/caikit_runtime_server_wrapper.py b/caikit/runtime/interceptors/caikit_runtime_server_wrapper.py index bd84ff028..951aafcd2 100644 --- a/caikit/runtime/interceptors/caikit_runtime_server_wrapper.py +++ b/caikit/runtime/interceptors/caikit_runtime_server_wrapper.py @@ -66,10 +66,7 @@ def __init__(self, server, global_predict, intercepted_svc_package: ServicePacka # concatenate it with the intercepted service name to produce # a fully qualified RPC method name that we wish to intercept # (e.g., '/natural_language_understanding.CaikitRuntime/SyntaxIzumoPredict') - fqm = 
"/%s/%s" % ( - self._intercepted_svc_package.descriptor.full_name, - method.name, - ) + fqm = f"/{self._intercepted_svc_package.descriptor.full_name}/{method.name}" log.info("", "Intercepting RPC method %s", fqm) self._intercepted_methods.append((method.name, fqm)) diff --git a/caikit/runtime/model_management/batcher.py b/caikit/runtime/model_management/batcher.py index 8198d7d3f..9fe950316 100644 --- a/caikit/runtime/model_management/batcher.py +++ b/caikit/runtime/model_management/batcher.py @@ -258,12 +258,12 @@ def _batch_thread_run(self): # pylint: disable=consider-iterating-dictionary new_kwargs = [ kwarg_name - for kwarg_name in req_kwargs.keys() + for kwarg_name in req_kwargs if kwarg_name not in batch_kwargs ] missing_kwargs = [ kwarg_name - for kwarg_name in batch_kwargs.keys() + for kwarg_name in batch_kwargs if kwarg_name not in req_kwargs ] @@ -359,9 +359,10 @@ def _batch_thread_run(self): log.debug4(batch_kwargs) batch_res = self._model.run_batch(**batch_kwargs) # pylint: disable=line-too-long - assert len(batch_res) == len( - current_batch - ), f"Got result of size [{len(batch_res)}] for batch of size [{len(current_batch)}]" + assert len(batch_res) == len(current_batch), ( + f"Got result of size [{len(batch_res)}] for batch" + "of size [{len(current_batch)}]" + ) for i, (req_id, event, _) in enumerate(current_batch): self._finished_tasks[req_id] = batch_res[i] event.set() diff --git a/caikit/runtime/model_management/model_loader.py b/caikit/runtime/model_management/model_loader.py index 2a3c39960..1a1ce043d 100644 --- a/caikit/runtime/model_management/model_loader.py +++ b/caikit/runtime/model_management/model_loader.py @@ -26,13 +26,10 @@ # Local from caikit.config import get_config from caikit.core import MODEL_MANAGER, ModuleBase +from caikit.core.model_management import ModelFinderBase, ModelInitializerBase from caikit.runtime.model_management.batcher import Batcher from caikit.runtime.model_management.loaded_model import LoadedModel from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException -from caikit.runtime.work_management.abortable_action import ( - AbortableAction, - ActionAborter, -) log = alog.use_channel("MODEL-LOADER") @@ -63,9 +60,10 @@ def load_model( model_id: str, local_model_path: str, model_type: str, - aborter: Optional[ActionAborter] = None, fail_callback: Optional[Callable] = None, retries: int = 0, + finder: Optional[Union[str, ModelFinderBase]] = None, + initializer: Optional[Union[str, ModelInitializerBase]] = None, ) -> LoadedModel: """Start loading a model from disk and associate the ID/size with it @@ -73,8 +71,6 @@ def load_model( model_id (str): Model ID string for the model to load. local_model_path (str): Local filesystem path to load the model from. model_type (str): Type of the model to load. 
- aborter (Optional[ActionAborter]): An aborter to use that will allow - the call's parent to abort the load fail_callback (Optional[Callable]): Optional no-arg callback to call on load failure retries (int): Number of times to retry loading @@ -92,30 +88,37 @@ def load_model( ) # Set up the async loading - args = (local_model_path, model_id, model_type) + args = (local_model_path, model_id, model_type, finder, initializer) log.debug2("Loading model %s async", model_id) - if aborter is not None: - log.debug3("Using abortable action to load %s", model_id) - action = AbortableAction(aborter, self._load_module, *args) - future_factory = partial(self._load_thread_pool.submit, action.do) - else: - future_factory = partial( - self._load_thread_pool.submit, self._load_module, *args - ) + future_factory = partial( + self._load_thread_pool.submit, self._load_module, *args + ) model_builder.model_future_factory(future_factory) # Return the built model with the future handle return model_builder.build() def _load_module( - self, model_path: str, model_id: str, model_type: str + self, + model_path: str, + model_id: str, + model_type: str, + finder: Optional[Union[str, ModelFinderBase]] = None, + initializer: Optional[Union[str, ModelInitializerBase]] = None, ) -> LoadedModel: try: log.info("", "Loading model '%s'", model_id) + # Only pass finder/initializer if they have values + load_kwargs = {} + if finder: + load_kwargs["finder"] = finder + if initializer: + load_kwargs["initializer"] = initializer + # Load using the caikit.core with CAIKIT_CORE_LOAD_DURATION_SUMMARY.labels(model_type=model_type).time(): - model = MODEL_MANAGER.load(model_path) + model = MODEL_MANAGER.load(model_path, **load_kwargs) # If this model needs batching, configure a Batcher to wrap it model = self._wrap_in_batcher_if_configured( @@ -126,8 +129,7 @@ def _load_module( except FileNotFoundError as fnfe: log_dict = { "log_code": "", - "message": "load failed to find model: %s with error: %s" - % (model_path, repr(fnfe)), + "message": f"load failed to find model: {model_path} with error: {repr(fnfe)}", "model_id": model_id, } log.error(log_dict) @@ -135,11 +137,21 @@ def _load_module( StatusCode.NOT_FOUND, f"Model {model_id} not found. Nested error: {fnfe}", ) from fnfe + except ValueError as ve: + log_dict = { + "log_code": "", + "message": f"load failed to find model: {model_path} with error: {repr(ve)}", + "model_id": model_id, + } + log.error(log_dict) + raise CaikitRuntimeException( + StatusCode.NOT_FOUND, + f"Model {model_id} not found. 
Nested error: {ve}", + ) from ve except Exception as ex: log_dict = { "log_code": "", - "message": "load failed when processing path: %s with error: %s" - % (model_path, repr(ex)), + "message": f"load failed when processing path: {model_path} with error: {repr(ex)}", "model_id": model_id, } log.error(log_dict, exc_info=True) diff --git a/caikit/runtime/model_management/model_manager.py b/caikit/runtime/model_management/model_manager.py index c3d0c46b7..a29b20f7c 100644 --- a/caikit/runtime/model_management/model_manager.py +++ b/caikit/runtime/model_management/model_manager.py @@ -14,11 +14,13 @@ # Standard from collections import Counter as DictCounter from functools import partial -from typing import Dict, Optional +from pathlib import Path +from typing import Dict, Optional, Union import atexit import gc import os import threading +import time # Third Party from grpc import StatusCode @@ -31,11 +33,11 @@ from caikit import get_config from caikit.core import ModuleBase from caikit.core.exceptions import error_handler +from caikit.core.model_management import ModelFinderBase, ModelInitializerBase from caikit.runtime.model_management.loaded_model import LoadedModel from caikit.runtime.model_management.model_loader import ModelLoader from caikit.runtime.model_management.model_sizer import ModelSizer from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException -from caikit.runtime.work_management.abortable_action import ActionAborter log = alog.use_channel("MODEL-MANAGR") error = error_handler.get(log) @@ -110,11 +112,26 @@ def __init__(self): # Keep track of whether lazy loading is enabled self._lazy_load_local_models = runtime_cfg.lazy_load_local_models - error.value_check( - "", - not self._lazy_load_local_models or self._local_models_dir, - "Must set runtime.local_models_dir with runtime.lazy_load_local_models", - ) + + if self._lazy_load_local_models: + error.value_check( + "", + runtime_cfg.local_models_dir is not None, + ( + "runtime.local_models_dir must be set" + " if using runtime.lazy_load_local_models. " + ), + ) + + error.value_check( + "", + self._local_models_dir, + ( + "runtime.local_models_dir must be a valid path" + " if set with runtime.lazy_load_local_models. " + f"Provided path: {runtime_cfg.local_models_dir}" + ), + ) # Set up local model periodic sync self._lazy_load_poll_period_seconds = runtime_cfg.lazy_load_poll_period_seconds @@ -131,13 +148,28 @@ def __init__(self): and self._lazy_load_local_models and self._lazy_load_poll_period_seconds ) + self._lazy_load_write_detection_period_seconds = ( + runtime_cfg.lazy_load_write_detection_period_seconds + ) + error.type_check( + "", + int, + float, + allow_none=True, + lazy_load_write_detection_period_seconds=self._lazy_load_write_detection_period_seconds, + ) if self._enable_lazy_load_poll: atexit.register(self.shut_down) # Do the initial local models load if self._local_models_dir: - log.info("", "Loading local models into Caikit Runtime...") - self.sync_local_models(wait=True) + wait = runtime_cfg.wait_for_initial_model_loads + log.info( + "", + "Loading local models into Caikit Runtime. 
Wait: %s", + wait, + ) + self.sync_local_models(wait=wait) def shut_down(self): """Shut down cache purging""" @@ -156,19 +188,19 @@ def load_model( local_model_path: str, model_type: str, wait: bool = True, - aborter: Optional[ActionAborter] = None, retries: Optional[int] = None, - ) -> int: + finder: Optional[Union[str, ModelFinderBase]] = None, + initializer: Optional[Union[str, ModelInitializerBase]] = None, + ) -> LoadedModel: """Load a model using model_path (in Cloud Object Storage) & give it a model ID Args: model_id (str): Model ID string for the model to load. local_model_path (str): Local path to load the model from. model_type (str): Type of the model to load. wait (bool): Wait for the model to finish loading - aborter (Optional[ActionAborter]): The aborter to use for this load - retries: Optional[int]: Number of times to retry on load failure + retries (Optional[int]): Number of times to retry on load failure Returns: - Model_size (int) : Size of the loaded model in bytes + model (LoadedModel): The LoadedModel instance """ with LOAD_MODEL_DURATION_SUMMARY.labels(model_type=model_type).time(): @@ -179,7 +211,7 @@ def load_model( model = self.loaded_models.get(model_id) if model is not None: log.debug("Model '%s' is already loaded", model_id) - return model.size() + return model # Grab the mutation lock and load the model if needed with self._loaded_models_lock: @@ -192,9 +224,10 @@ def load_model( model_id, local_model_path, model_type, - aborter=aborter, fail_callback=partial(self.unload_model, model_id), retries=retries, + finder=finder, + initializer=initializer, ) except Exception as ex: self.__increment_load_model_exception_count_metric(model_type) @@ -224,8 +257,8 @@ def load_model( if wait: model.wait() - # Return the model's size - return model.size() + # Return the loaded model handle + return model def sync_local_models(self, wait: bool = False): """Sync in-memory models with models in the configured local_model_dir @@ -234,7 +267,7 @@ def sync_local_models(self, wait: bool = False): be unloaded. 
Args: - wait (bool): Wait for loading to complete + wait (bool): After starting all loads, wait for them to complete """ try: self._local_models_dir_sync(wait) @@ -376,8 +409,8 @@ def retrieve_model(self, model_id: str) -> ModuleBase: ) # Now retrieve the model and fall back to lazy loading - model_loaded = model_id in self.loaded_models - if not model_loaded and self._lazy_load_local_models: + loaded_model = self.loaded_models.get(model_id) + if not loaded_model and self._lazy_load_local_models: local_model_path = os.path.join(self._local_models_dir, model_id) log.debug2( "Lazy loading local model %s from %s", model_id, local_model_path @@ -389,17 +422,16 @@ def retrieve_model(self, model_id: str) -> ModuleBase: if not os.path.exists(local_model_path): log.debug2("Attempting to load ephemeral model %s", model_id) local_model_path = model_id - self.load_model( + loaded_model = self.load_model( model_id=model_id, local_model_path=local_model_path, model_type=self._LOCAL_MODEL_TYPE, wait=True, retries=get_config().runtime.lazy_load_retries, ) - model_loaded = True # If still not loaded, there's nothing to find, so raise NOT_FOUND - if not model_loaded: + if not loaded_model: msg = f"Model '{model_id}' not loaded" log.debug( {"log_code": "", "message": msg, "model_id": model_id} @@ -410,7 +442,7 @@ def retrieve_model(self, model_id: str) -> ModuleBase: # NOTE: If the model is partially loaded, this call will wait on the # model future in the LoadedModel - return self.loaded_models[model_id].model() + return loaded_model.model() ## Implementation Details ## @@ -430,8 +462,14 @@ def _local_models_dir_sync(self, wait: bool = False): try: disk_models = os.listdir(self._local_models_dir) except FileNotFoundError as err: + log.error( + "", "Failed to read model ids from disk", exc_info=True + ) raise StopIteration() from err + log.debug3("All models found in local disk cache: %s", disk_models) + log.debug3("Currently loaded models: %s", list(self.loaded_models.keys())) + # Find all models that aren't currently loaded new_models = [ model_id for model_id in disk_models if model_id not in self.loaded_models @@ -443,16 +481,19 @@ def _local_models_dir_sync(self, wait: bool = False): unload_models = [ model_id for model_id, loaded_model in self.loaded_models.items() - if model_id not in disk_models - and loaded_model.path().startswith( - self._local_models_dir, - ) + if loaded_model.path().startswith(self._local_models_dir) + and not os.path.exists(loaded_model.path()) ] log.debug("Unloaded local models: %s", unload_models) # Load new models for model_id in new_models: model_path = os.path.join(self._local_models_dir, model_id) + + if self._model_write_in_progress(model_path): + log.debug("Model %s is still being written", model_id) + continue + self.load_model( model_id, model_path, @@ -483,14 +524,43 @@ def _local_models_dir_sync(self, wait: bool = False): try: loaded_model.wait() except CaikitRuntimeException as err: - log.warning( - "", + log.debug( + "", "Failed to load model %s: %s", model_id, repr(err), exc_info=True, ) + def _model_write_in_progress(self, model_dir: str) -> bool: + """Returns true if model_dir is currently being written to. Uses the + runtime.lazy_load_write_detection_period_seconds configuration to sleep between + consecutive size checks of the directory. + + Always returns false if runtime.lazy_load_write_detection_period_seconds is zero, + negative, or None. 
+        """
+        if (
+            self._lazy_load_write_detection_period_seconds is None
+            or self._lazy_load_write_detection_period_seconds <= 0
+        ):
+            return False
+
+        # Get the current directory size
+        size = self._get_total_disk_size(model_dir)
+        # Sleep a bit to wait out another write
+        time.sleep(self._lazy_load_write_detection_period_seconds)
+        # Get the size again. If it has changed, then a write is currently in progress
+        return self._get_total_disk_size(model_dir) != size
+
+    @staticmethod
+    def _get_total_disk_size(model_dir: str) -> int:
+        """Returns the sum of st_size of all files contained within the directory structure rooted
+        at model_dir.
+        """
+        dir_path = Path(model_dir)
+        return sum([f.stat().st_size for f in dir_path.rglob("*") if f.is_file()])
+
     def __report_total_model_size_metric(self):
         # Just a happy little lock to ensure that with concurrent loading and unloading,
         # the last metric reported will be correct.
diff --git a/caikit/runtime/model_management/model_sizer.py b/caikit/runtime/model_management/model_sizer.py
index 09a8fd0eb..3f5b0c294 100644
--- a/caikit/runtime/model_management/model_sizer.py
+++ b/caikit/runtime/model_management/model_sizer.py
@@ -106,9 +106,9 @@ def __get_archive_size(self, model_id, local_model_path) -> int:
             # Probably just an archive file
             return os.path.getsize(local_model_path)
         except FileNotFoundError as ex:
-            message = "Failed to estimate size of model '%s', file '%s' not found" % (
-                model_id,
-                local_model_path,
+            message = (
+                f"Failed to estimate size of model '{model_id}', "
+                f"file '{local_model_path}' not found"
             )
             log.error("", message)
             raise CaikitRuntimeException(grpc.StatusCode.NOT_FOUND, message) from ex
diff --git a/caikit/runtime/names.py b/caikit/runtime/names.py
new file mode 100644
index 000000000..79c10dfcf
--- /dev/null
+++ b/caikit/runtime/names.py
@@ -0,0 +1,321 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+! NOTE ! This file should not import any extra dependencies. It is intended for
+use by client libraries that do not necessarily use a specific runtime server
+type.
+""" + +# Standard +from enum import Enum +from typing import Optional, Type, Union +import re + +# First Party +import alog + +# Local +from caikit.config import get_config +from caikit.core.modules import ModuleBase +from caikit.core.task import TaskBase +from caikit.core.toolkit.name_tools import snake_to_upper_camel +from caikit.interfaces.runtime.data_model import ( + ModelInfoRequest, + ModelInfoResponse, + RuntimeInfoRequest, + RuntimeInfoResponse, + TrainingInfoRequest, + TrainingStatusResponse, +) + +log = alog.use_channel("RNTM-NAMES") + + +############# Serice Names ############## + + +class ServiceType(Enum): + """Common class for describing service types""" + + INFERENCE = 1 # Inference service for the GlobalPredictServicer + TRAINING = 2 # Training service for the GlobalTrainServicer + TRAINING_MANAGEMENT = 3 + INFO = 4 + + +############# Serice Name Generation ############## + + +## Service Package Descriptors + + +def get_ai_domain() -> str: + """Get the string name for the AI domain + + Returns: + domain(str): The domain for this service + """ + caikit_config = get_config() + lib = caikit_config.runtime.library + default_ai_domain_name = snake_to_upper_camel(lib.replace("caikit_", "")) + ai_domain_name = ( + caikit_config.runtime.service_generation.domain or default_ai_domain_name + ) + return ai_domain_name + + +def get_service_package_name(service_type: Optional[ServiceType] = None) -> str: + """This helper will get the name of service package + + Args: + service_type Optional[ServiceType]: The Service Type's package name to fetch defaults + to runtime + + Returns: + str: The name of the service package + """ + + # If specific service_type was provided then return their packages + if service_type == ServiceType.INFO: + return INFO_SERVICE_PACKAGE + elif service_type == ServiceType.TRAINING_MANAGEMENT: + return TRAINING_MANAGEMENT_PACKAGE + + caikit_config = get_config() + ai_domain_name = get_ai_domain() + default_package_name = f"caikit.runtime.{ai_domain_name}" + package_name = ( + caikit_config.runtime.service_generation.package or default_package_name + ) + return package_name + + +def get_service_name(service_type: ServiceType) -> str: + """This helper will get the name of the service + + Args: + service_type ServiceType: The Service Type whose name to fetch + + Returns: + str: The name of the service + """ + if service_type == ServiceType.INFERENCE: + return f"{get_ai_domain()}Service" + elif service_type == ServiceType.TRAINING: + return f"{get_ai_domain()}TrainingService" + elif service_type == ServiceType.TRAINING_MANAGEMENT: + return TRAINING_MANAGEMENT_SERVICE_NAME + elif service_type == ServiceType.INFO: + return INFO_SERVICE_NAME + + +## Service RPC Descriptors + + +def get_train_rpc_name(module_class: Type[ModuleBase]) -> str: + """Helper function to convert from the name of a module to the name of the + request RPC function + """ + + # 🌶️🌶️🌶️ The naming scheme for training RPCs probably needs to change. + # This uses the first task from the `tasks` kwarg in the `@caikit.module` decorator. 
+ # This is both: + # - Flaky, since re-ordering that list would be perfectly reasonable and valid to do except + # for the side effect of breaking the training service api + # - Not very intuitive, since a module supporting multiple tasks will have a training + # endpoint that lists only one of them + rpc_name = snake_to_upper_camel( + f"{next(iter(module_class.tasks)).__name__}_{module_class.__name__}_Train" + ) + + if len(module_class.tasks) > 1: + log.warning( + "", + "Multiple tasks detected for training rpc. " + "Module: [%s], Tasks: [%s], RPC name: %s ", + module_class, + module_class.tasks, + rpc_name, + ) + + return rpc_name + + +def get_task_predict_rpc_name( + task_or_module_class: Type[Union[ModuleBase, TaskBase]], + input_streaming: bool = False, + output_streaming: bool = False, +) -> str: + """Helper function to get the name of a task's RPC""" + task_class = ( + next(iter(task_or_module_class.tasks)) + if issubclass(task_or_module_class, ModuleBase) + else task_or_module_class + ) + + if input_streaming and output_streaming: + return snake_to_upper_camel(f"BidiStreaming{task_class.__name__}_Predict") + if output_streaming: + return snake_to_upper_camel(f"ServerStreaming{task_class.__name__}_Predict") + if input_streaming: + return snake_to_upper_camel(f"ClientStreaming{task_class.__name__}_Predict") + return snake_to_upper_camel(f"{task_class.__name__}_Predict") + + +## Service DataModel Name Descriptors + + +def get_train_request_name(module_class: Type[ModuleBase]) -> str: + """Helper function to get the request name of a Train Service""" + return f"{get_train_rpc_name(module_class)}Request" + + +def get_train_parameter_name(module_class: Type[ModuleBase]) -> str: + """Helper function to get the inner request parameter name of a Train Service""" + return f"{get_train_rpc_name(module_class)}Parameters" + + +def get_task_predict_request_name( + task_or_module_class: Type[Union[ModuleBase, TaskBase]], + input_streaming: bool = False, + output_streaming: bool = False, +) -> str: + """Helper function to get the name of an RPC's request data type""" + + task_class = ( + next(iter(task_or_module_class.tasks)) + if issubclass(task_or_module_class, ModuleBase) + else task_or_module_class + ) + + if input_streaming and output_streaming: + return snake_to_upper_camel(f"BidiStreaming{task_class.__name__}_Request") + if output_streaming: + return snake_to_upper_camel(f"ServerStreaming{task_class.__name__}_Request") + if input_streaming: + return snake_to_upper_camel(f"ClientStreaming{task_class.__name__}_Request") + return snake_to_upper_camel(f"{task_class.__name__}_Request") + + +## Service Definitions + +TRAINING_MANAGEMENT_SERVICE_NAME = "TrainingManagement" +TRAINING_MANAGEMENT_PACKAGE = "caikit.runtime.training" +TRAINING_MANAGEMENT_SERVICE_SPEC = { + "service": { + "rpcs": [ + { + "name": "GetTrainingStatus", + "input_type": TrainingInfoRequest.get_proto_class().DESCRIPTOR.full_name, + "output_type": TrainingStatusResponse.get_proto_class().DESCRIPTOR.full_name, + }, + { + "name": "CancelTraining", + "input_type": TrainingInfoRequest.get_proto_class().DESCRIPTOR.full_name, + "output_type": TrainingStatusResponse.get_proto_class().DESCRIPTOR.full_name, + }, + ] + } +} + +INFO_SERVICE_NAME = "InfoService" +INFO_SERVICE_PACKAGE = "caikit.runtime.info" +INFO_SERVICE_SPEC = { + "service": { + "rpcs": [ + { + "name": "GetRuntimeInfo", + "input_type": RuntimeInfoRequest.get_proto_class().DESCRIPTOR.full_name, + "output_type": RuntimeInfoResponse.get_proto_class().DESCRIPTOR.full_name, 
+ }, + { + "name": "GetModelsInfo", + "input_type": ModelInfoRequest.get_proto_class().DESCRIPTOR.full_name, + "output_type": ModelInfoResponse.get_proto_class().DESCRIPTOR.full_name, + }, + ] + } +} + + +############### Server Names ############# + +# Invocation metadata key for the model ID, provided by Model Mesh +MODEL_MESH_MODEL_ID_KEY = "mm-model-id" + + +## HTTP Server + +# Endpoint to use for health checks +HEALTH_ENDPOINT = "/health" + +# Endpoint to use for server info +RUNTIME_INFO_ENDPOINT = "/info/version" +MODELS_INFO_ENDPOINT = "/info/models" + +# These keys are used to define the logical sections of the request and response +# data structures. +REQUIRED_INPUTS_KEY = "inputs" +OPTIONAL_INPUTS_KEY = "parameters" +MODEL_ID = "model_id" + +# Stream event types enum +class StreamEventTypes(Enum): + MESSAGE = "message" + ERROR = "error" + + +def get_http_route_name(rpc_name: str) -> str: + """Function to get the http route for a given rpc name + + Args: + rpc_name (str): The name of the Caikit RPC + + Raises: + NotImplementedError: If the RPC is not a Train or Predict RPC + + Returns: + str: The name of the http route for RPC + """ + if rpc_name.endswith("Predict"): + task_name = re.sub( + r"(? str: + """Function to get GRPC name for a given service type and rpc name + + Args: + rpc_name (str): The name of the Caikit RPC + + Returns: + str: The name of the GRPC route for RPC + """ + return f"/{get_service_package_name(service_type)}.{get_service_name(service_type)}/{rpc_name}" diff --git a/caikit/runtime/server_base.py b/caikit/runtime/server_base.py index 182ac657b..19156e0f8 100644 --- a/caikit/runtime/server_base.py +++ b/caikit/runtime/server_base.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Base class with common functionality across all caikit servers""" - # Standard +from concurrent.futures import ThreadPoolExecutor from typing import Optional import abc import signal @@ -27,15 +27,43 @@ # Local from caikit.config import get_config +from caikit.core.exceptions import error_handler from caikit.runtime.model_management.model_manager import ModelManager from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException +from caikit.runtime.work_management.abortable_context import ThreadInterrupter import caikit log = alog.use_channel("SERVR-BASE") +error = error_handler.get(log) + + +class ServerThreadPool: + """Simple wrapper for all servers to share a single thread pool""" + @staticmethod + def _build_pool() -> ThreadPoolExecutor: + config = caikit.get_config() + # Leave in backwards compatibility for the old runtime.grpc.server_thread_pool_size + # parameter, which many users may have deployed with. 
+        if pool_size := config.runtime.grpc.server_thread_pool_size:
+            log.info("Using legacy runtime.grpc.server_thread_pool_size configuration")
+        else:
+            pool_size = config.runtime.server_thread_pool_size
+
+        error.type_check("", int, pool_size=pool_size)
+
+        pool = ThreadPoolExecutor(
+            max_workers=pool_size, thread_name_prefix="caikit_runtime"
+        )

-class RuntimeServerBase(abc.ABC):
+        return pool
+
+    # py3.9 compatibility: Can't call @staticmethod on class attribute initialization
+    pool = _build_pool.__get__(object, None)()
+
+
+class RuntimeServerBase(abc.ABC):  # pylint: disable=too-many-instance-attributes
     __doc__ = __doc__
     _metrics_server_started = False
@@ -76,6 +104,26 @@ def __init__(self, base_port: int, tls_config_override: Optional[aconfig.Config]

         self.training_service = training_service

+        # create runtime info service
+        self.runtime_info_service: Optional[
+            ServicePackage
+        ] = ServicePackageFactory.get_service_package(
+            ServicePackageFactory.ServiceType.INFO,
+        )
+
+        self.thread_pool: ThreadPoolExecutor = ServerThreadPool.pool
+        # Create an interrupter that can be used to handle request cancellations or timeouts.
+        # A separate instance is held per-server so each server can handle the lifetime of their
+        # own interrupter.
+        self.interrupter: Optional[ThreadInterrupter] = (
+            ThreadInterrupter() if self.config.runtime.use_abortable_threads else None
+        )
+
+        # Handle interrupts
+        # NB: This means that stop() methods will be called even if the process is interrupted
+        # before the start() method is called
+        self._intercept_interrupt_signal()
+
     @classmethod
     def _start_metrics_server(cls) -> None:
         """Start a single instance of the metrics server based on configuration"""
@@ -88,11 +136,6 @@ def _start_metrics_server(cls) -> None:
             start_http_server(get_config().runtime.metrics.port)
             cls._metrics_server_started = True

-    def _intercept_interrupt_signal(self) -> None:
-        """intercept signal handler"""
-        signal.signal(signal.SIGINT, self.interrupt)
-        signal.signal(signal.SIGTERM, self.interrupt)
-
     def interrupt(self, signal_, _stack_frame):
         log.info(
             "",
@@ -101,6 +144,43 @@ def interrupt(self, signal_, _stack_frame):
         )
         self.stop()

+    def _intercept_interrupt_signal(self) -> None:
+        """Intercept signal handlers to allow the server to stop on interrupt.
+        Calling this on a non-main thread has no effect.
+        This does not override any existing non-default signal handlers;
+        it will call them all in the reverse order they were registered.
+        """
+        self._add_signal_handler(signal.SIGINT, self.interrupt)
+        self._add_signal_handler(signal.SIGTERM, self.interrupt)
+
+    @staticmethod
+    def _add_signal_handler(sig, handler):
+        def nested_interrupt_builder(*handlers):
+            """Build and return an interrupt handler that calls all of *handlers"""
+
+            log.debug("Building interrupt handler: %s", handlers)
+
+            def interrupt(signal_, _stack_frame):
+                for handler in handlers:
+                    log.debug("Running interrupt handler: %s", handler)
+                    # Only call the handler if it is a callable fn that is _not_ a default handler
+                    if (
+                        handler
+                        and callable(handler)
+                        and handler != signal.SIG_DFL
+                        and handler is not signal.default_int_handler
+                    ):
+                        handler(signal_, _stack_frame)
+
+            return interrupt
+
+        try:
+            signal.signal(sig, nested_interrupt_builder(handler, signal.getsignal(sig)))
+        except ValueError:
+            log.info(
+                "Unable to register signal handler. Server was started from a non-main thread."
+ ) + def _shut_down_model_manager(self): """Shared utility for shutting down the model manager""" ModelManager.get_instance().shut_down() diff --git a/caikit/runtime/service_factory.py b/caikit/runtime/service_factory.py index 7c630d8f3..d533474c4 100644 --- a/caikit/runtime/service_factory.py +++ b/caikit/runtime/service_factory.py @@ -13,7 +13,6 @@ # limitations under the License. """This module is responsible for creating service objects for the runtime to consume""" # Standard -from enum import Enum from types import ModuleType from typing import Callable, Dict, Set, Type, Union import dataclasses @@ -38,10 +37,22 @@ from caikit.core.exceptions import error_handler from caikit.core.task import TaskBase from caikit.interfaces.runtime.data_model import ( + ModelInfoRequest, + ModelInfoResponse, + RuntimeInfoRequest, + RuntimeInfoResponse, TrainingInfoRequest, TrainingStatusResponse, ) from caikit.runtime import service_generation +from caikit.runtime.names import ServiceType as InterfaceServiceType +from caikit.runtime.names import ( + get_service_name, + get_service_package_name, + get_task_predict_request_name, + get_train_parameter_name, + get_train_request_name, +) from caikit.runtime.service_generation.rpcs import CaikitRPCBase from caikit.runtime.utils import import_util @@ -66,6 +77,24 @@ } } +INFO_SERVICE_NAME = "InfoService" +INFO_SERVICE_SPEC = { + "service": { + "rpcs": [ + { + "name": "GetRuntimeInfo", + "input_type": RuntimeInfoRequest.get_proto_class().DESCRIPTOR.full_name, + "output_type": RuntimeInfoResponse.get_proto_class().DESCRIPTOR.full_name, + }, + { + "name": "GetModelsInfo", + "input_type": ModelInfoRequest.get_proto_class().DESCRIPTOR.full_name, + "output_type": ModelInfoResponse.get_proto_class().DESCRIPTOR.full_name, + }, + ] + } +} + @dataclasses.dataclass class ServicePackage: @@ -90,10 +119,7 @@ class ServicePackage: class ServicePackageFactory: """Factory responsible for yielding the correct concrete ServicePackage implementation""" - class ServiceType(Enum): - INFERENCE = 1 # Inference service for the GlobalPredictServicer - TRAINING = 2 # Training service for the GlobalTrainServicer - TRAINING_MANAGEMENT = 3 + ServiceType = InterfaceServiceType @classmethod def get_service_package( @@ -130,13 +156,28 @@ def get_service_package( caikit_rpcs={}, # No caikit RPCs ) + if service_type == cls.ServiceType.INFO: + grpc_service = json_to_service( + name=INFO_SERVICE_NAME, + package="caikit.runtime.info", + json_service_def=INFO_SERVICE_SPEC, + ) + + return ServicePackage( + service=grpc_service.service_class, + descriptor=grpc_service.descriptor, + registration_function=grpc_service.registration_function, + stub_class=grpc_service.client_stub_class, + messages=None, # we don't need messages here + caikit_rpcs={}, # No caikit RPCs + ) + # First make sure we import the data model for the correct library # !!!! 
This will use the `caikit_library` config _ = import_util.get_data_model() # Get the names for the AI domain and the proto package - ai_domain_name = service_generation.get_ai_domain() - package_name = service_generation.get_runtime_service_package() + package_name = get_service_package_name(service_type) # Then do API introspection to come up with all the API definitions to support caikit_config = get_config() @@ -144,6 +185,7 @@ def get_service_package( caikit_config, caikit_config.runtime.library, write_modules_file ) + service_name = get_service_name(service_type) if service_type == cls.ServiceType.INFERENCE: # Assert for backwards compatibility, if enabled, when service type is INFERENCE ServicePackageFactory._check_backwards_compatibility( @@ -153,10 +195,8 @@ def get_service_package( rpc_list = service_generation.create_inference_rpcs( clean_modules, caikit_config ) - service_name = f"{ai_domain_name}Service" else: # service_type == cls.ServiceType.TRAINING rpc_list = service_generation.create_training_rpcs(clean_modules) - service_name = f"{ai_domain_name}TrainingService" rpc_list = [rpc for rpc in rpc_list if rpc.return_type is not None] @@ -204,7 +244,7 @@ def _check_backwards_compatibility( "prev_modules_path {} is not a valid file path or is missing permissions", prev_modules_path, ) - with open(prev_modules_path, "r", encoding="utf-8") as f: + with open(prev_modules_path, encoding="utf-8") as f: previous_modules = json.load(f) previous_included_task_map = previous_modules["included_modules"] for task_module in previous_included_task_map.values(): @@ -318,20 +358,12 @@ def get_inference_request( ModuleBase, TaskBase, ) - task_class = ( - next(iter(task_or_module_class.tasks)) - if issubclass(task_or_module_class, ModuleBase) - else task_or_module_class - ) - if input_streaming and output_streaming: - request_class_name = f"BidiStreaming{task_class.__name__}Request" - elif input_streaming: - request_class_name = f"ClientStreaming{task_class.__name__}Request" - elif output_streaming: - request_class_name = f"ServerStreaming{task_class.__name__}Request" - else: - request_class_name = f"{task_class.__name__}Request" + request_class_name = get_task_predict_request_name( + task_or_module_class, + input_streaming=input_streaming, + output_streaming=output_streaming, + ) log.debug( "Request class name %s for class %s.", request_class_name, task_or_module_class ) @@ -345,10 +377,7 @@ def get_train_request(module_class: Type[ModuleBase]) -> Type[DataBase]: module_class, ModuleBase, ) - # 🌶️🌶️🌶️ This is coupled to the naming scheme code in - # caikit.runtime.service_generation.rpcs, which is likely to change. 
- first_task = next(iter(module_class.tasks)) - request_class_name = f"{first_task.__name__}{module_class.__name__}TrainRequest" + request_class_name = get_train_request_name(module_class) log.debug("Request class name %s for module %s.", request_class_name, module_class) return DataBase.get_class_for_name(request_class_name) @@ -360,8 +389,7 @@ def get_train_params(module_class: Type[ModuleBase]) -> Type[DataBase]: module_class, ModuleBase, ) - first_task = next(iter(module_class.tasks)) - request_class_name = f"{first_task.__name__}{module_class.__name__}TrainParameters" + request_class_name = get_train_parameter_name(module_class) log.debug("Request class name %s for module %s.", request_class_name, module_class) return DataBase.get_class_for_name(request_class_name) diff --git a/caikit/runtime/service_generation/__init__.py b/caikit/runtime/service_generation/__init__.py index ebbe62808..18a79a7e7 100644 --- a/caikit/runtime/service_generation/__init__.py +++ b/caikit/runtime/service_generation/__init__.py @@ -5,4 +5,3 @@ create_inference_rpcs, create_training_rpcs, ) -from .proto_package import get_ai_domain, get_runtime_service_package diff --git a/caikit/runtime/service_generation/create_service.py b/caikit/runtime/service_generation/create_service.py index debbd231a..b2109cbdc 100644 --- a/caikit/runtime/service_generation/create_service.py +++ b/caikit/runtime/service_generation/create_service.py @@ -110,7 +110,7 @@ def create_inference_rpcs( exc_info=True, ) - return rpcs + return sorted(rpcs, key=lambda x: x.name) def create_training_rpcs(modules: List[Type[ModuleBase]]) -> List[CaikitRPCBase]: @@ -156,7 +156,7 @@ def create_training_rpcs(modules: List[Type[ModuleBase]]) -> List[CaikitRPCBase] err, exc_info=True, ) - return rpcs + return sorted(rpcs, key=lambda x: x.name) def _group_modules_by_task( diff --git a/caikit/runtime/service_generation/data_stream_source.py b/caikit/runtime/service_generation/data_stream_source.py index fcbd216e0..9b6576686 100644 --- a/caikit/runtime/service_generation/data_stream_source.py +++ b/caikit/runtime/service_generation/data_stream_source.py @@ -41,7 +41,7 @@ ListOfFileReferences, S3Files, ) -from caikit.runtime.service_generation.proto_package import get_runtime_service_package +from caikit.runtime.names import get_service_package_name from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException import caikit @@ -321,7 +321,7 @@ def get_stream_message_type(self, element_type: type) -> Type[DataBase]: if stream_message_type: return stream_message_type - package = get_runtime_service_package() + package = get_service_package_name() cls_name = _make_data_stream_source_type_name(element_type) JsonData = make_dataobject( package=package, @@ -422,8 +422,9 @@ def __init__(self): super().__init__(self._generator) def _generator(self): - stream = self.to_data_stream() - return stream.generator_func(*stream.generator_args, **stream.generator_kwargs) + return self._stream.generator_func( + *self._stream.generator_args, **self._stream.generator_kwargs + ) def __getstate__(self) -> bytes: """A DataStreamSource is pickled by serializing its source @@ -450,6 +451,13 @@ def name_to_plugin_map(self): plugin.get_field_name(self.ELEMENT_TYPE): plugin for plugin in self.PLUGINS } + @cached_property + def _stream(self): + """The internal _stream is cached here so that the result of calling to_data_stream can be + re-read, rather than requiring to_data_stream to be invoked on every read through the + stream""" + return self.to_data_stream() + 
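A note on the `cached_property` used above: it computes `to_data_stream()` once on first
access and stores the result on the instance, so repeated passes through `_generator`
re-read the same underlying `DataStream` instead of re-running the source conversion. A
minimal, self-contained sketch of that memoization behavior (the `ExampleSource` class
below is illustrative only, not part of caikit):

# Standard
from functools import cached_property


class ExampleSource:
    """Toy stand-in showing the cached_property memoization pattern"""

    conversions = 0

    @cached_property
    def _stream(self):
        # Expensive conversion: runs only on the first attribute access, after
        # which the result is cached in the instance's __dict__
        ExampleSource.conversions += 1
        return [1, 2, 3]


source = ExampleSource()
assert source._stream is source._stream  # same cached object on every read
assert ExampleSource.conversions == 1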
# pylint: disable=too-many-return-statements def to_data_stream(self) -> DataStream: """Convert to the target data stream type based on the source type""" @@ -498,7 +506,7 @@ def make_data_stream_source( log.debug2("Looking for DataStreamSource[%s]", data_element_type) if data_element_type not in _DATA_STREAM_SOURCE_TYPES: cls_name = _make_data_stream_source_type_name(data_element_type) - package = get_runtime_service_package() + package = get_service_package_name() log.debug("Creating DataStreamSource[%s] -> %s", data_element_type, cls_name) @@ -589,7 +597,7 @@ def __init__(self, *args, **kwargs): ) from err DataStreamSourceBase.__init__(self) - setattr(data_object, "__init__", __init__) + data_object.__init__ = __init__ _DATA_STREAM_SOURCE_TYPES[data_element_type] = data_object diff --git a/caikit/runtime/service_generation/proto_package.py b/caikit/runtime/service_generation/proto_package.py deleted file mode 100644 index 99adb088a..000000000 --- a/caikit/runtime/service_generation/proto_package.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This module holds the common helpers for managing the protobuf package name for -the runtime -""" - -# Local -from ...config import get_config - - -def snake_to_upper_camel(string: str) -> str: - """Simple snake -> upper camel conversion""" - return "".join([part[0].upper() + part[1:] for part in string.split("_")]) - - -def get_ai_domain() -> str: - """Get the string name for the AI domain""" - caikit_config = get_config() - lib = caikit_config.runtime.library - default_ai_domain_name = snake_to_upper_camel(lib.replace("caikit_", "")) - ai_domain_name = ( - caikit_config.runtime.service_generation.domain or default_ai_domain_name - ) - return ai_domain_name - - -def get_runtime_service_package() -> str: - """This helper will get the common runtime package""" - caikit_config = get_config() - ai_domain_name = get_ai_domain() - default_package_name = f"caikit.runtime.{ai_domain_name}" - package_name = ( - caikit_config.runtime.service_generation.package or default_package_name - ) - return package_name diff --git a/caikit/runtime/service_generation/protoable.py b/caikit/runtime/service_generation/protoable.py index ad5f44232..81e3ad4d0 100644 --- a/caikit/runtime/service_generation/protoable.py +++ b/caikit/runtime/service_generation/protoable.py @@ -179,14 +179,14 @@ def is_protoable_type(arg_type: Type) -> bool: protoable = True elif typing.get_origin(arg_type) == list: log.debug2("Arg is List") - if len(typing.get_args(arg_type)) == 0: + if not typing.get_args(arg_type): log.debug2("List annotation has no type") protoable = False else: protoable = is_protoable_type(typing.get_args(arg_type)[0]) elif typing.get_origin(arg_type) == dict: log.debug2("Arg is Dict") - if len(typing.get_args(arg_type)) == 0: + if not typing.get_args(arg_type): log.debug2("Dict annotation has no type") protoable = False else: diff --git a/caikit/runtime/service_generation/rpcs.py 
b/caikit/runtime/service_generation/rpcs.py index c67ceacbf..ca44bab10 100644 --- a/caikit/runtime/service_generation/rpcs.py +++ b/caikit/runtime/service_generation/rpcs.py @@ -34,12 +34,18 @@ from . import protoable, type_helpers from .compatibility_checker import ApiFieldNames from .data_stream_source import make_data_stream_source -from .proto_package import snake_to_upper_camel from caikit.core import ModuleBase, TaskBase from caikit.core.data_model.base import DataBase from caikit.core.data_model.dataobject import make_dataobject from caikit.core.signature_parsing import CaikitMethodSignature, CustomSignature from caikit.interfaces.runtime.data_model import ModelPointer, TrainingJob +from caikit.runtime.names import ( + get_task_predict_request_name, + get_task_predict_rpc_name, + get_train_parameter_name, + get_train_request_name, + get_train_rpc_name, +) log = alog.use_channel("RPC-SERIALIZERS") @@ -152,28 +158,7 @@ def module_class_to_rpc_name(module_class: Type[ModuleBase]) -> str: """Helper function to convert from the name of a module to the name of the request RPC function """ - # 🌶️🌶️🌶️ The naming scheme for training RPCs probably needs to change. - # This uses the first task from the `tasks` kwarg in the `@caikit.module` decorator. - # This is both: - # - Flaky, since re-ordering that list would be perfectly reasonable and valid to do except - # for the side effect of breaking the training service api - # - Not very intuitive, since a module supporting multiple tasks will have a training - # endpoint that lists only one of them - rpc_name = snake_to_upper_camel( - f"{next(iter(module_class.tasks)).__name__}_{module_class.__name__}_Train" - ) - - if len(module_class.tasks) > 1: - log.warning( - "", - "Multiple tasks detected for training rpc. 
" - "Module: [%s], Tasks: [%s], RPC name: %s ", - module_class, - module_class.tasks, - rpc_name, - ) - - return rpc_name + return get_train_rpc_name(module_class) @staticmethod def module_class_to_req_name(module_class: Type[ModuleBase]) -> str: @@ -185,7 +170,7 @@ def module_class_to_req_name(module_class: Type[ModuleBase]) -> str: return: SampleTaskSampleModuleTrainRequest """ - return f"{ModuleClassTrainRPC.module_class_to_rpc_name(module_class)}Request" + return get_train_request_name(module_class) @staticmethod def module_class_to_inner_request_name(module_class: Type[ModuleBase]) -> str: @@ -196,7 +181,7 @@ def module_class_to_inner_request_name(module_class: Type[ModuleBase]) -> str: return: SampleTaskSampleModuleTrainParameters """ - return f"{ModuleClassTrainRPC.module_class_to_rpc_name(module_class)}Parameters" + return get_train_parameter_name(module_class) @staticmethod def _mutate_method_signature_for_training( @@ -348,23 +333,16 @@ def _handle_task_inputs(self, method_params: Dict[str, Any]) -> Dict[str, Any]: # for unary input cases req_params = self.task.get_required_parameters(input_streaming=False) for param_name, param_type in method_params.items(): - if param_name in req_params: - new_params[param_name] = req_params[param_name] - else: - new_params[param_name] = param_type + new_params[param_name] = req_params.get(param_name, param_type) return new_params def _task_to_req_name(self) -> str: """Helper function to convert the pair of library name and task name to a request message name """ - if self._input_streaming and self._output_streaming: - return snake_to_upper_camel(f"BidiStreaming{self.task.__name__}_Request") - if self._output_streaming: - return snake_to_upper_camel(f"ServerStreaming{self.task.__name__}_Request") - if self._input_streaming: - return snake_to_upper_camel(f"ClientStreaming{self.task.__name__}_Request") - return snake_to_upper_camel(f"{self.task.__name__}_Request") + return get_task_predict_request_name( + self.task, self._input_streaming, self._output_streaming + ) def _task_to_rpc_name(self) -> str: """Helper function to convert the pair of library name and task name @@ -374,13 +352,9 @@ def _task_to_rpc_name(self) -> str: return: SampleTaskPredict """ - if self._input_streaming and self._output_streaming: - return snake_to_upper_camel(f"BidiStreaming{self.task.__name__}_Predict") - if self._output_streaming: - return snake_to_upper_camel(f"ServerStreaming{self.task.__name__}_Predict") - if self._input_streaming: - return snake_to_upper_camel(f"ClientStreaming{self.task.__name__}_Predict") - return snake_to_upper_camel(f"{self.task.__name__}_Predict") + return get_task_predict_rpc_name( + self.task, self._input_streaming, self._output_streaming + ) class _RequestMessage: @@ -401,10 +375,7 @@ def __init__( existing_fields = ApiFieldNames.get_fields_for_message(self.name) - if len(existing_fields) > 0: - last_used_number = max(existing_fields.values()) - else: - last_used_number = 0 + last_used_number = max(existing_fields.values()) if existing_fields else 0 for _, (item_name, typ) in enumerate(params.items()): if item_name in existing_fields: diff --git a/caikit/runtime/servicers/global_predict_servicer.py b/caikit/runtime/servicers/global_predict_servicer.py index db79a3402..daf58168b 100644 --- a/caikit/runtime/servicers/global_predict_servicer.py +++ b/caikit/runtime/servicers/global_predict_servicer.py @@ -29,11 +29,12 @@ # Local from caikit import get_config -from caikit.core import ModuleBase +from caikit.core import ModuleBase, TaskBase 
from caikit.core.data_model import DataBase, DataStream from caikit.core.signature_parsing import CaikitMethodSignature from caikit.runtime.metrics.rpc_meter import RPCMeter from caikit.runtime.model_management.model_manager import ModelManager +from caikit.runtime.names import MODEL_MESH_MODEL_ID_KEY from caikit.runtime.service_factory import ServicePackage from caikit.runtime.service_generation.rpcs import TaskPredictRPC from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException @@ -45,7 +46,10 @@ get_metadata, validate_data_model, ) -from caikit.runtime.work_management.abortable_action import AbortableAction +from caikit.runtime.work_management.abortable_context import ( + AbortableContext, + ThreadInterrupter, +) from caikit.runtime.work_management.rpc_aborter import RpcAborter PREDICT_RPC_COUNTER = Counter( @@ -84,16 +88,13 @@ class GlobalPredictServicer: given request """ - # Invocation metadata key for the model ID, provided by Model Mesh - MODEL_MESH_MODEL_ID_KEY = "mm-model-id" - # Input size in code points, provided by orchestrator INPUT_SIZE_KEY = "input-length" def __init__( self, inference_service: ServicePackage, - use_abortable_threads: bool = get_config().runtime.use_abortable_threads, + interrupter: ThreadInterrupter = None, ): self._started_metering = False self._model_manager = ModelManager.get_instance() @@ -111,7 +112,7 @@ def __init__( "Metering is disabled, to enable set `metering.enabled` in config to true", ) - self.use_abortable_threads = use_abortable_threads + self._interrupter = interrupter self._inference_service = inference_service # Validate that the Caikit Library CDM is compatible with our service descriptor validate_data_model(self._inference_service.descriptor) @@ -155,70 +156,72 @@ def Predict( A Caikit Library data model response object """ # Make sure the request has a model before doing anything - model_id = get_metadata(context, self.MODEL_MESH_MODEL_ID_KEY) + model_id = get_metadata(context, MODEL_MESH_MODEL_ID_KEY) request_name = caikit_rpc.request.name - with self._handle_predict_exceptions(model_id, request_name): - with alog.ContextLog( - log.debug, "GlobalPredictServicer.Predict:%s", request_name - ): - # Retrieve the model from the model manager - log.debug("", "Retrieving model '%s'", model_id) - model = self._model_manager.retrieve_model(model_id) - model_class = type(model) - - # Little hackity hack: Calling _verify_model_task upfront here as well to - # short-circuit requests where the model is _totally_ unsupported - self._verify_model_task(model) - - # Unmarshall the request object into the required module run argument(s) - with PREDICT_FROM_PROTO_SUMMARY.labels( - grpc_request=request_name, model_id=model_id - ).time(): - inference_signature = model_class.get_inference_signature( - input_streaming=caikit_rpc.input_streaming, - output_streaming=caikit_rpc.output_streaming, - task=caikit_rpc.task, - ) - if not inference_signature: - raise CaikitRuntimeException( - StatusCode.INVALID_ARGUMENT, - f"Model class {model_class} does not support {caikit_rpc.name}", - ) - if caikit_rpc.input_streaming: - caikit_library_request = ( - self._build_caikit_library_request_stream( - request, inference_signature, caikit_rpc - ) - ) - else: - caikit_library_request = build_caikit_library_request_dict( - request, - inference_signature, - ) - response = self.predict_model( - request_name, - model_id, - inference_func_name=inference_signature.method_name, - aborter=RpcAborter(context) if self.use_abortable_threads else None, - 
**caikit_library_request,
+        with self._handle_predict_exceptions(model_id, request_name), alog.ContextLog(
+            log.debug, "GlobalPredictServicer.Predict:%s", request_name
+        ):
+            # Retrieve the model from the model manager
+            log.debug("", "Retrieving model '%s'", model_id)
+            model = self._model_manager.retrieve_model(model_id)
+            model_class = type(model)
+
+            # Little hackity hack: Calling _verify_model_task upfront here as well to
+            # short-circuit requests where the model is _totally_ unsupported
+            self._verify_model_task(model)
+
+            # Unmarshall the request object into the required module run argument(s)
+            with PREDICT_FROM_PROTO_SUMMARY.labels(
+                grpc_request=request_name, model_id=model_id
+            ).time():
+                inference_signature = model_class.get_inference_signature(
+                    input_streaming=caikit_rpc.input_streaming,
+                    output_streaming=caikit_rpc.output_streaming,
+                    task=caikit_rpc.task,
                )
+                if not inference_signature:
+                    raise CaikitRuntimeException(
+                        StatusCode.INVALID_ARGUMENT,
+                        f"Model class {model_class} does not support {caikit_rpc.name}",
+                    )
+                if caikit_rpc.input_streaming:
+                    caikit_library_request = self._build_caikit_library_request_stream(
+                        request, inference_signature, caikit_rpc
+                    )
+                else:
+                    caikit_library_request = build_caikit_library_request_dict(
+                        request,
+                        inference_signature,
+                    )
+                response = self.predict_model(
+                    request_name,
+                    model_id,
+                    input_streaming=caikit_rpc.input_streaming,
+                    output_streaming=caikit_rpc.output_streaming,
+                    task=caikit_rpc.task,
+                    aborter=RpcAborter(context) if self._interrupter else None,
+                    **caikit_library_request,
+                )

-            # Marshall the response to the necessary return type
-            with PREDICT_TO_PROTO_SUMMARY.labels(
-                grpc_request=request_name, model_id=model_id
-            ).time():
-                if caikit_rpc.output_streaming:
-                    response_proto = build_proto_stream(response)
-                else:
-                    response_proto = build_proto_response(response)
-            return response_proto
+            # Marshall the response to the necessary return type
+            with PREDICT_TO_PROTO_SUMMARY.labels(
+                grpc_request=request_name, model_id=model_id
+            ).time():
+                if caikit_rpc.output_streaming:
+                    response_proto = build_proto_stream(response)
+                else:
+                    response_proto = build_proto_response(response)
+            return response_proto

     def predict_model(
         self,
         request_name: str,
         model_id: str,
         inference_func_name: str = "run",
+        input_streaming: Optional[bool] = None,
+        output_streaming: Optional[bool] = None,
+        task: Optional[TaskBase] = None,
         aborter: Optional[RpcAborter] = None,
         **kwargs,
     ) -> Union[DataBase, Iterable[DataBase]]:
@@ -231,7 +234,14 @@ def predict_model(
             model_id (str):
                 The ID of the loaded model
             inference_func_name (str):
-                The name of the inference function to run
+                Explicit name of the inference function to predict (ignored if
+                input_streaming and output_streaming are set)
+            input_streaming (Optional[bool]):
+                Use the task function with input streaming
+            output_streaming (Optional[bool]):
+                Use the task function with output streaming
+            task (Optional[TaskBase]):
+                The task to use for inference (if multitask model)
             aborter (Optional[RpcAborter]):
                 If using abortable calls, this is the aborter to use
             **kwargs: Keyword arguments to pass to the model's run function
@@ -244,23 +254,29 @@ def predict_model(
         with self._handle_predict_exceptions(model_id, request_name):
             model = self._model_manager.retrieve_model(model_id)
             self._verify_model_task(model)
-
+            if input_streaming is not None and output_streaming is not None:
+                inference_func_name = model.get_inference_signature(
+                    output_streaming=output_streaming,
+
input_streaming=input_streaming, + task=task, + ).method_name + log.debug2("Deduced inference function name: %s", inference_func_name) + + model_run_fn = getattr(model, inference_func_name) # NB: we previously recorded the size of the request, and timed this module to # provide a rudimentary throughput metric of size / time + # 🌶️🌶️🌶️ The `AbortableContext` will only abort if both `self._interrupter` and + # `aborter` are set with alog.ContextLog( log.debug, "GlobalPredictServicer.Predict.caikit_library_run:%s", request_name, + ), PREDICT_CAIKIT_LIBRARY_SUMMARY.labels( + grpc_request=request_name, model_id=model_id + ).time(), AbortableContext( + aborter, self._interrupter ): - model_run_fn = getattr(model, inference_func_name) - with PREDICT_CAIKIT_LIBRARY_SUMMARY.labels( - grpc_request=request_name, model_id=model_id - ).time(): - if aborter is not None: - work = AbortableAction(aborter, model_run_fn, **kwargs) - response = work.do() - else: - response = model_run_fn(**kwargs) + response = model_run_fn(**kwargs) # Update Prometheus metrics PREDICT_RPC_COUNTER.labels( @@ -391,7 +407,7 @@ def call_build_request_dict(request: ProtobufMessage) -> Dict[str, Any]: ) stream_num += 1 - for param in streaming_params.keys(): + for param in streaming_params: # For each "streaming" parameter, grab one of the tee'd streams and map it to return # a `DataStream` of that individual parameter diff --git a/caikit/runtime/servicers/global_train_servicer.py b/caikit/runtime/servicers/global_train_servicer.py index 507dfca17..3ebd61a37 100644 --- a/caikit/runtime/servicers/global_train_servicer.py +++ b/caikit/runtime/servicers/global_train_servicer.py @@ -297,12 +297,3 @@ def rpc_termination_callback(): model_name=request.model_name, training_id=model_future.id, ) - - def _load_trained_model(self, model_name: str, model_path: str): - log.debug("Autoloading trained model %s", model_name) - self._model_manager.load_model( - model_id=model_name, - local_model_path=model_path, - model_type="standalone", - ) - return self._model_manager.retrieve_model(model_name) diff --git a/caikit/runtime/servicers/info_servicer.py b/caikit/runtime/servicers/info_servicer.py new file mode 100644 index 000000000..ca7b68635 --- /dev/null +++ b/caikit/runtime/servicers/info_servicer.py @@ -0,0 +1,156 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the implementation for retrieving information about the +library and services. +""" +# Have pylint ignore Class XXXX has no YYYY member so that we can use gRPC enums. 
+# pylint: disable=E1101 + +# Standard +from typing import Any, Dict, List, Optional, Union + +# Third Party +from grpc import StatusCode +import importlib_metadata + +# First Party +import alog + +# Local +from caikit.config import get_config +from caikit.interfaces.runtime.data_model import ( + ModelInfo, + ModelInfoRequest, + ModelInfoResponse, + RuntimeInfoResponse, +) +from caikit.runtime.model_management.model_manager import ModelManager +from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException + +log = alog.use_channel("RI-SERVICR-I") + + +class InfoServicer: + """This class contains the implementation for retrieving information about the + library and services.""" + + def GetModelsInfo( + self, request: ModelInfoRequest, context # pylint: disable=unused-argument + ) -> ModelInfoResponse: + """Get information on the loaded models for the GRPC server + + Args: + request: ModelInfoRequest + context + + Returns: + models_info: ModelInfoResponse + DataObject containing the model info + """ + return self._get_models_info(model_ids=request.model_ids).to_proto() + + def get_models_info_dict( + self, model_ids: Optional[List[str]] + ) -> Dict[str, List[Dict[str, Any]]]: + """Get information on models for the HTTP server + + Returns: + model_info_dict: Dict[str, List[Dict[str, str]]] + Dict representation of ModelInfoResponse + """ + return self._get_models_info(model_ids=model_ids).to_dict() + + def _get_models_info( + self, model_ids: Optional[List[str]] = None + ) -> ModelInfoResponse: + """Helper function to get the list of models + + Returns: + model_info: ModelInfoResponse + DataObject with model information + """ + model_manager = ModelManager.get_instance() + + # Get list of models based on input list or all loaded models + loaded_model_list = [] + if model_ids: + for model_name in model_ids: + loaded_model = model_manager.loaded_models.get(model_name) + if not loaded_model: + raise CaikitRuntimeException( + StatusCode.NOT_FOUND, f"Model {model_name} is not loaded" + ) + + loaded_model_list.append((model_name, loaded_model)) + else: + loaded_model_list = model_manager.loaded_models.items() + + # Get all loaded models + response = ModelInfoResponse(models=[]) + for name, loaded_module in loaded_model_list: + model_instance = loaded_module.model() + response.models.append( + ModelInfo( + model_path=loaded_module.path(), + name=name, + size=loaded_module.size(), + metadata=model_instance.public_model_info, + module_id=model_instance.MODULE_ID, + module_metadata=model_instance.module_metadata, + ) + ) + return response + + def GetRuntimeInfo( + self, request, context # pylint: disable=unused-argument + ) -> RuntimeInfoResponse: + """Get information on versions of libraries and server for GRPC""" + return self._get_runtime_info().to_proto() + + def get_version_dict(self) -> Dict[str, Union[str, Dict]]: + """Get information on versions of libraries and server for HTTP""" + return self._get_runtime_info().to_dict() + + def _get_runtime_info(self) -> RuntimeInfoResponse: + """Get information on versions of libraries and server from config""" + config_version_info = get_config().runtime.version_info or {} + python_packages = { + package: version + for package, version in config_version_info.get( + "python_packages", {} + ).items() + if package != "all" + } + all_packages = (config_version_info.get("python_packages") or {}).get("all") + + for lib, dist_names in importlib_metadata.packages_distributions().items(): + if ( + all_packages or (len(lib.split(".")) == 1 and 
lib.startswith("caikit")) + ) and (version := self._try_lib_version(dist_names[0])): + python_packages[lib] = version + + runtime_image = config_version_info.get("runtime_image") + + return RuntimeInfoResponse( + python_packages=python_packages, + runtime_version=runtime_image, + ) + + def _try_lib_version(self, name) -> str: + """Get version of python modules""" + try: + return importlib_metadata.version(name) + except importlib_metadata.PackageNotFoundError: + return None diff --git a/caikit/runtime/servicers/model_runtime_servicer.py b/caikit/runtime/servicers/model_runtime_servicer.py index ac8460248..c2d50c381 100644 --- a/caikit/runtime/servicers/model_runtime_servicer.py +++ b/caikit/runtime/servicers/model_runtime_servicer.py @@ -24,7 +24,10 @@ from caikit.runtime.protobufs import model_runtime_pb2, model_runtime_pb2_grpc from caikit.runtime.types.aborted_exception import AbortedException from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException -from caikit.runtime.work_management.abortable_action import AbortableAction +from caikit.runtime.work_management.abortable_context import ( + AbortableContext, + ThreadInterrupter, +) from caikit.runtime.work_management.rpc_aborter import RpcAborter log = alog.use_channel("MR-SERVICR-I") @@ -34,8 +37,9 @@ class ModelRuntimeServicerImpl(model_runtime_pb2_grpc.ModelRuntimeServicer): """This class contains the implementation of all of the RPCs that are required to run a service in Model Mesh as a Model-Runtime.""" - def __init__(self): + def __init__(self, interrupter: ThreadInterrupter = None): self.model_manager = ModelManager.get_instance() + self.interrupter = interrupter def loadModel(self, request, context): """Model loading . @@ -51,32 +55,21 @@ def loadModel(self, request, context): log.info( { "log_code": "", - "message": "Loading model '%s'" % request.modelId, + "message": f"Loading model '{request.modelId}'", "model_id": request.modelId, } ) - caikit_config = get_config() - if caikit_config.runtime.use_abortable_threads: - aborter = RpcAborter(context) - work = AbortableAction( - aborter, - self.model_manager.load_model, - request.modelId, - request.modelPath, - request.modelType, - aborter=aborter, - ) - model_size = work.do() - else: - model_size = self.model_manager.load_model( + aborter = RpcAborter(context) if self.interrupter else None + with AbortableContext(aborter=aborter, interrupter=self.interrupter): + loaded_model = self.model_manager.load_model( request.modelId, request.modelPath, request.modelType ) + model_size = loaded_model.size() log.info( { "log_code": "", - "message": "Model '%s' loaded! Model size [%s]" - % (request.modelId, str(model_size)), + "message": f"Model '{request.modelId}' loaded! Model size [{model_size}]", "model_id": request.modelId, } ) @@ -85,8 +78,7 @@ def loadModel(self, request, context): log.warning( { "log_code": "", - "message": "Model '%s' was not loaded due to the rpc aborting" - % request.modelId, + "message": f"Model '{request.modelId}' was not loaded due to the rpc aborting", "model_id": request.modelId, "error_id": e.id, } @@ -101,8 +93,10 @@ def loadModel(self, request, context): log.warning( { "log_code": "", - "message": "Model '%s' could not be loaded! Reason: [%s]" - % (request.modelId, str(e.message)), + "message": ( + f"Model '{request.modelId}' could not be loaded!" 
+ f"Reason: [{e.message}]" + ), "model_id": request.modelId, "error_id": e.id, } @@ -137,7 +131,7 @@ def unloadModel(self, request, context): log.info( { "log_code": "", - "message": "Unloading model '%s'" % request.modelId, + "message": f"Unloading model '{request.modelId}'", "model_id": request.modelId, } ) @@ -145,8 +139,7 @@ def unloadModel(self, request, context): log.info( { "log_code": "", - "message": "Unloaded model '%s' (Reclaimed size: %s)" - % (request.modelId, model_size), + "message": f"Unloaded model '{request.modelId}' (Reclaimed size: {model_size})", "model_id": request.modelId, } ) @@ -154,8 +147,10 @@ def unloadModel(self, request, context): log.warning( { "log_code": "", - "message": "Model '%s' could not be unloaded! Reason: [%s]" - % (request.modelId, str(e.message)), + "message": ( + f"Model '{request.modelId}' could not be unloaded!" + f"Reason: [{e.message}]" + ), "model_id": request.modelId, "error_id": e.id, } @@ -178,7 +173,7 @@ def predictModelSize(self, request, context): log.info( { "log_code": "", - "message": "Predicting size of model '%s'" % request.modelId, + "message": f"Predicting size of model '{request.modelId}'", "model_id": request.modelId, } ) @@ -188,8 +183,7 @@ def predictModelSize(self, request, context): log.info( { "log_code": "", - "message": "Predicted model '%s' size: [%s]" - % (request.modelId, str(predicted_size)), + "message": f"Predicted model '{request.modelId}' size: [{predicted_size}]", "model_id": request.modelId, } ) @@ -198,8 +192,10 @@ def predictModelSize(self, request, context): log.warning( { "log_code": "", - "message": "Model '%s' size could not be predicted! Reason: [%s]" - % (request.modelId, e.message), + "message": ( + f"Model '{request.modelId}' size could not be predicted!" + f"Reason: [e.message]" + ), "model_id": request.modelId, "error_id": e.id, } @@ -222,7 +218,7 @@ def modelSize(self, request, context): log.info( { "log_code": "", - "message": "Computing size of model '%s'" % request.modelId, + "message": f"Computing size of model '{request.modelId}'", "model_id": request.modelId, } ) @@ -230,8 +226,7 @@ def modelSize(self, request, context): log.info( { "log_code": "", - "message": "Computed model '%s' size: [%s]" - % (request.modelId, str(model_size)), + "message": f"Computed model '{request.modelId}' size: [{model_size}]", "model_id": request.modelId, } ) @@ -239,8 +234,10 @@ def modelSize(self, request, context): log.warning( { "log_code": "", - "message": "Failed to calculate model '%s' size! Reason: [%s]" - % (request.modelId, e.message), + "message": ( + f"Failed to calculate model '{request.modelId}' size!" + f"Reason: [{e.message}]" + ), "model_id": request.modelId, "error_id": e.id, } diff --git a/caikit/runtime/servicers/model_train_servicer.py b/caikit/runtime/servicers/model_train_servicer.py index 2af53c925..6bf41f02b 100644 --- a/caikit/runtime/servicers/model_train_servicer.py +++ b/caikit/runtime/servicers/model_train_servicer.py @@ -188,7 +188,7 @@ def _update_file_references( ) training_request_data_model = dm_class.from_proto(train_message_request) # 2. Find any data streams - for attr_name in training_request_data_model.__annotations__.keys(): + for attr_name in training_request_data_model.__annotations__: val = getattr(training_request_data_model, attr_name) if isinstance(val, DataStreamSourceBase): # 3. 
Look for file pointers and update them diff --git a/caikit/runtime/types/aborted_exception.py b/caikit/runtime/types/aborted_exception.py index 1360e1c24..114b9e43f 100644 --- a/caikit/runtime/types/aborted_exception.py +++ b/caikit/runtime/types/aborted_exception.py @@ -25,5 +25,10 @@ class AbortedException(CaikitRuntimeException): message: Because Exceptions usually have those """ - def __init__(self, message): + def __init__(self, message: str = None): + if not message: + message = ( + "Work in this thread was aborted by a context manager. " + "This is usually due to a client timeout or cancellation." + ) super().__init__(grpc.StatusCode.ABORTED, message) diff --git a/caikit/runtime/utils/import_util.py b/caikit/runtime/utils/import_util.py index 864a9e6b2..fb3f74b23 100644 --- a/caikit/runtime/utils/import_util.py +++ b/caikit/runtime/utils/import_util.py @@ -161,10 +161,7 @@ def get_dynamic_module(module_name: str, module_dir: str = None) -> ModuleType: Returns: (module): Handle to the module after dynamic import """ - if module_dir: - module_path = module_dir + "." + module_name - else: - module_path = module_name + module_path = f"{module_dir}.{module_name}" if module_dir else module_name log.info("", "Loading service module: %s", module_path) # Try to find the spec for the module that we're interested in. spec = importlib.util.find_spec(module_path) diff --git a/caikit/runtime/utils/servicer_util.py b/caikit/runtime/utils/servicer_util.py index 5bce68303..033df252b 100644 --- a/caikit/runtime/utils/servicer_util.py +++ b/caikit/runtime/utils/servicer_util.py @@ -220,7 +220,6 @@ def validate_data_model( # and check to make that it is either a primitive protobufs type or that # we have a data model class that we can deserialize the protobufs with if not is_protobuf_primitive_field(field): - if field.message_type and field.message_type.GetOptions().map_entry: log.debug( "", @@ -400,9 +399,7 @@ def build_caikit_library_request_dict( # 2. Remove any fields not in the module signature absent_field_names = [ - field - for field in kwargs_dict.keys() - if field not in module_signature.parameters.keys() + field for field in kwargs_dict if field not in module_signature.parameters ] for absent_field_name in absent_field_names: kwargs_dict.pop(absent_field_name) @@ -410,7 +407,6 @@ def build_caikit_library_request_dict( # 3. Handle type conversions updated_kwargs = {} for field_name, field_value in kwargs_dict.items(): - # 3.1 Model Pointers if isinstance(field_value, ModelPointer): log.debug2("field %s value is a ModelPointer obj", field_name) diff --git a/caikit/runtime/work_management/abortable_action.py b/caikit/runtime/work_management/abortable_action.py deleted file mode 100644 index b561d27d1..000000000 --- a/caikit/runtime/work_management/abortable_action.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Standard -from typing import Callable -import abc -import threading - -# First Party -import alog - -# Local -from caikit.core.toolkit.destroyable_thread import DestroyableThread -from caikit.runtime.types.aborted_exception import AbortedException - -log = alog.use_channel("ABORT-ACTION") - - -class ActionAborter(abc.ABC): - """Simple interface to wrap up a notification that an action must abort. - - Children of this class can bind to any notification tool (e.g. grpc context) - """ - - @abc.abstractmethod - def must_abort(self) -> bool: - """Indicate whether or not the action must be aborted""" - - @abc.abstractmethod - def add_event(self, event: threading.Event): - """Add an event to notify when abort happens""" - - -class AbortableAction: - """A class for Abortable Actions. We want actions that are computationally heavy to be - abortable by Model Mesh! Currently, we use this for the following operations. - - - Loading a model - - Predicting with a model - - Training a model - - In the future, this may include getting the size of a model, depending on how that we choose - to implement that. - - How it works: - Instances of this class create a threading.Event, which will be used to signal that either: - - The RPC was terminated - - The heavy work that we wanted to complete is done - This is done by using a RpcAborter and a DestroyableThread. - Registering the event with the RpcAborter will cause it to set when the RPC is - terminated, and creating a DestroyableThread with the event will cause it to set when - the thread terminates. - - The action will start the DestroyableThread and then wait on the event. When it wakes, it - will check the reason and destroy the thread if it was woken by the RpcAborter or return - the result if it was woken by the thread completing. - """ - - def __init__( - self, - call_aborter: ActionAborter, - runnable_func: Callable, - *args, - **kwargs, - ): - """ - Args: - call_aborter - call aborter capable of aborting the runnable_func - runnable_func - the function to be run as an abortable action - *args - nonkeyword arguments to runnable_func - **kwargs - keyword arguments to runnable_func""" - - # Create new event to watch for both RPC termination and work completion - self.__done_or_aborted_event = threading.Event() - - # Register the event with our call aborter so it fires if the RPC terminates - self.call_aborter = call_aborter - self.call_aborter.add_event(self.__done_or_aborted_event) - - # Create a new thread to do the work, which will set the event if it finished - self.__runnable_func = runnable_func - self.__work_thread = DestroyableThread( - self.__runnable_func, - *args, - work_done_event=self.__done_or_aborted_event, - **kwargs, - ) - - def do(self): - # Start the work and wait - self.__work_thread.start() - self.__done_or_aborted_event.wait() - - # Now, check the call aborter to see what happened. - # Option 1: The RPC was terminated. Kill the work thread and raise an exception - if self.call_aborter.must_abort(): - log.info( - "", "Aborting work in progress: %s", self.__runnable_func - ) - self.__work_thread.destroy() - self.__work_thread.join() - raise AbortedException("Aborted work: {}".format(self.__runnable_func)) - - # Options 2: Work thread finished normally. Hooray! 
- log.debug("Work finished: %s", self.__runnable_func) - self.__work_thread.join() - return self.__work_thread.get_or_throw() diff --git a/caikit/runtime/work_management/abortable_context.py b/caikit/runtime/work_management/abortable_context.py new file mode 100644 index 000000000..240dd0536 --- /dev/null +++ b/caikit/runtime/work_management/abortable_context.py @@ -0,0 +1,175 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Standard +from queue import SimpleQueue +from typing import Dict, Optional +import abc +import ctypes +import threading +import uuid + +# First Party +import alog + +# Local +from caikit.runtime.types.aborted_exception import AbortedException + +log = alog.use_channel("ABORT-ACTION") + + +class ActionAborter(abc.ABC): + """Simple interface to wrap up a notification that an action must abort. + + Children of this class can bind to any notification tool (e.g. grpc context) + """ + + @abc.abstractmethod + def must_abort(self) -> bool: + """Indicate whether or not the action must be aborted""" + + @abc.abstractmethod + def set_context(self, context: "AbortableContext"): + """Set the abortable context that must be notified to abort work""" + + @abc.abstractmethod + def unset_context(self): + """Unset any abortable context already held. Do not notify it that work should abort""" + + +class ThreadInterrupter: + """This class implements a listener which will observe all ongoing work in `AbortableContexts` + and raise exceptions in the working threads if they need to be aborted. + + The implementation spawns a single extra thread to wait on any contexts to abort, and + interrupt the thread that the context is running in. This keeps the total number of running + threads much smaller than using a new thread to monitor each AbortableContext. 
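+
+    A usage sketch (assuming `grpc_context` is the live gRPC ServicerContext for the
+    RPC doing the work; `RpcAborter` is defined in rpc_aborter.py and binds RPC
+    termination to the context below):
+
+        interrupter = ThreadInterrupter()
+        interrupter.start()  # once, e.g. at server startup
+
+        aborter = RpcAborter(grpc_context)
+        with AbortableContext(aborter, interrupter):
+            model.run(**kwargs)  # AbortedException is raised here if the RPC terminates
+
+        interrupter.stop()  # once, at server shutdown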
+ """ + + _SHUTDOWN_SIGNAL = -1 + + def __init__(self): + # Using a SimpleQueue because we don't need the Queue's task api + self._queue = SimpleQueue() + self._thread: Optional[threading.Thread] = None + self._context_thread_map: Dict[uuid.UUID, int] = {} + self._start_stop_lock = threading.Lock() + + def start(self): + """Start the watch loop that will abort any registered contexts passed to .kill()""" + with self._start_stop_lock: + if self._thread and self._thread.is_alive(): + log.debug("ThreadInterrupter already started") + return + log.debug("Starting ThreadInterrupter") + self._thread = threading.Thread(target=self._watch_loop) + self._thread.start() + + def stop(self): + """Stop the watch loop""" + with self._start_stop_lock: + if self._thread and not self._thread.is_alive(): + log.debug("ThreadInterrupter already shut down") + return + + log.info("Stopping ThreadInterrupter") + self._queue.put(self._SHUTDOWN_SIGNAL) + self._thread.join(timeout=1) + + def register(self, context_id: uuid, thread: int) -> None: + self._context_thread_map[context_id] = thread + + def unregister(self, context_id: uuid) -> None: + self._context_thread_map.pop(context_id, None) + + def kill(self, context_id: uuid) -> None: + # Put this context onto the queue for abortion and immediately return + self._queue.put(context_id, block=False) + + def _watch_loop(self): + while True: + log.debug("Waiting on any work to abort") + context_id = self._queue.get() + + if context_id == self._SHUTDOWN_SIGNAL: + log.debug("Ending abort watch loop") + return + + self._kill_thread(context_id) + + # Ensure this context/thread pair is unregistered + self.unregister(context_id) + + def _kill_thread(self, context_id: uuid.UUID) -> bool: + thread_id = self._context_thread_map.get(context_id, None) + + if thread_id: + log.debug("Interrupting thread id: ", thread_id) + # This raises an AbortedException asynchronously in the target thread. (We can't just + # use raise, because this thread is the ThreadInterrupter's watch thread). + # The exception will only be raised once the target thread regains control of the + # python interpreter. This means that statements like `time.sleep(9999999)` cannot be + # interrupted in this manner. + async_exception_result = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(thread_id), ctypes.py_object(AbortedException) + ) + if async_exception_result > 1: + log.warning("Failed to abort thread") + return False + + return True + + else: + log.warning("AbortableWork context already unregistered") + return False + + +class AbortableContext: + """Context manager for running work inside a context where it's safe to abort. + + This is a class instead of a `@contextmanager` function because __exit__ needs to + happen on exception. + """ + + def __init__(self, aborter: ActionAborter, interrupter: ThreadInterrupter): + """Setup the context. + The aborter is responsible for notifying this context if the work needs to be aborted. 
+ The interrupter watches all such events, and kills the thread running in this context + if the aborter notifies it to abort.""" + self.aborter = aborter + self.interrupter = interrupter + + self.id = uuid.uuid4() + + def __enter__(self): + if self.aborter and self.interrupter: + log.debug4("Entering abortable context %s", self.id) + # Set this context on the aborter so that it can notify us when work should be aborted + self.aborter.set_context(self) + # Register this context with the interrupter so that it knows which thread to kill + thread_id = threading.get_ident() + self.interrupter.register(self.id, thread_id) + else: + log.debug4("Aborter or Interrupter was None, no abortable context created.") + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.aborter and self.interrupter: + # On any exit, whether an exception or not, we unregister with the interrupter + # This prevents the interrupter from aborting this thread once this context has ended + self.interrupter.unregister(self.id) + self.aborter.unset_context() + + def abort(self): + """Called by the aborter when this context needs to be aborted""" + if self.interrupter: + self.interrupter.kill(self.id) diff --git a/caikit/runtime/work_management/rpc_aborter.py b/caikit/runtime/work_management/rpc_aborter.py index 39f8df003..8ad951125 100644 --- a/caikit/runtime/work_management/rpc_aborter.py +++ b/caikit/runtime/work_management/rpc_aborter.py @@ -15,9 +15,6 @@ """ This module helps us know when an rpc call is cancelled, and we need to stop or undo work """ -# Standard -import threading - # Third Party import grpc @@ -25,8 +22,10 @@ import alog # Local -from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException -from caikit.runtime.work_management.abortable_action import ActionAborter +from caikit.runtime.work_management.abortable_context import ( + AbortableContext, + ActionAborter, +) log = alog.use_channel("CALL-ABORTER") @@ -38,17 +37,15 @@ class RpcAborter(ActionAborter): relinquished control anyway. The interesting case is when a client cancels a call or a deadline is hit, which could trigger this callback but will not interrupt the thread doing work. - In order to actually interrupt threads doing the work, events can be registered with an - instance of this class in order ton receive notification on RPC termination. + In order to actually interrupt threads doing the work, abortable contexts can be registered + with an instance of this class in order to receive notification on RPC termination. IFF the RPC has been terminated, `must_abort` will return True.
""" def __init__(self, context: grpc.ServicerContext): - # Create an event that we can use to check RPC termination self.is_terminated = False - # Add an empty list for condition variables that will be notified on termination - self.events = [] + self.context = None callback_registered = context.add_callback(self.__rpc_terminated) @@ -57,27 +54,22 @@ def __init__(self, context: grpc.ServicerContext): if not callback_registered: log.warning( "", - "Failed to register rpc termination callback, aborting rpc", - ) - raise CaikitRuntimeException( - grpc.StatusCode.ABORTED, - "Could not register RPC callback, call has likely terminated.", + "Failed to register rpc termination callback, call has likely terminated", ) + self.is_terminated = True def must_abort(self): return self.is_terminated - def add_event(self, event: threading.Event): - self.events.append(event) - - # Sanity check: If we have already terminated, notify anything waiting on this condition + def set_context(self, context: AbortableContext): + self.context = context if self.must_abort(): - event.set() + self.context.abort() + + def unset_context(self): + self.context = None def __rpc_terminated(self): - # First set the flag so anybody waiting on us knows that gRPC wants us to abort work self.is_terminated = True - - # Then notify everybody waiting on us - for event in self.events: - event.set() + if self.context: + self.context.abort() diff --git a/caikit/version.py b/caikit/version.py index e88d411d4..2b1603c68 100644 --- a/caikit/version.py +++ b/caikit/version.py @@ -1,7 +1,6 @@ -# pylint: disable=unused-import try: # Local - from ._version import __version__, __version_tuple__ + from ._version import __version__, __version_tuple__ # noqa: F401 # unused import except ImportError: __version__ = "unknown" version_tuple = (0, 0, __version__) diff --git a/caikit_health_probe/__init__.py b/caikit_health_probe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/caikit_health_probe/__main__.py b/caikit_health_probe/__main__.py new file mode 100644 index 000000000..3a12337c4 --- /dev/null +++ b/caikit_health_probe/__main__.py @@ -0,0 +1,358 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module implements common health probes (liveness and readiness) for all +running runtime servers. +""" +# Standard +from contextlib import contextmanager +from typing import List, Optional, Tuple +import importlib.util +import os +import sys +import tempfile +import warnings + +# Third Party +import psutil + +# First Party +import alog + +# We play some tricks to import caikit's config here and avoid importing all of +# caikit itself. The reason for this is that importing caikit can be very costly +# relative to the actual cost of probing the servers and since this is intended +# to stand alone as an executable, that import time cost is added to every call. 
+caikit_spec = importlib.util.find_spec("caikit") +sys.path = [os.path.dirname(caikit_spec.origin)] + sys.path +# Third Party +from config import get_config + +sys.path = sys.path[1:] + + +log = alog.use_channel("PROBE") + + +@alog.timed_function(log.debug) +def readiness_probe() -> bool: + """Run a readiness probe against all running runtime servers. + + This function is intended to be run from an environment where the config is + identical to the config that the server is running such as from inside a + kubernetes pod where the server is also running. + + Returns: + ready (bool): True if all servers are ready to take requests, False + otherwise + """ + + # Get TLS key/cert files if possible + config = get_config() + tls_key = config.runtime.tls.server.key + tls_cert = config.runtime.tls.server.cert + client_ca = config.runtime.tls.client.cert + http_ready, grpc_ready = None, None + + if config.runtime.http.enabled: + log.debug("Checking HTTP server health") + http_ready = _http_readiness_probe( + config.runtime.http.port, tls_key, tls_cert, client_ca + ) + + if config.runtime.grpc.enabled: + log.debug("Checking gRPC server health") + grpc_ready = _grpc_readiness_probe( + config.runtime.grpc.port, tls_key, tls_cert, client_ca + ) + + if False in [http_ready, grpc_ready]: + log.info( + "", + "Runtime server(s) not ready. HTTP: %s, gRPC: %s", + http_ready, + grpc_ready, + ) + return False + return True + + +@alog.timed_function(log.debug) +def liveness_probe(runtime_proc_identifier: str = "caikit.runtime") -> bool: + # Get all running processes that we have access to + this_proc = psutil.Process() + this_exe = this_proc.exe() + procs = [_get_proc_info(pid) for pid in psutil.pids() if pid != this_proc.pid] + + # Filter down to caikit runtime processes + caikit_procs = [ + proc_info + for proc_info in procs + if proc_info is not None + and proc_info[0] == this_exe + and any(runtime_proc_identifier in arg for arg in proc_info[1]) + ] + + # If we have running caikit processes, we consider the server to be alive + return bool(caikit_procs) + + +## Implementation ############################################################## + + +def _get_proc_info(pid: int) -> Optional[Tuple[str, List[str]]]: + """Attempt to get the given pid's information (exe and cmdline)""" + try: + proc = psutil.Process(pid) + return (proc.exe(), proc.cmdline()) + except psutil.Error: + return None + + +def _http_readiness_probe( + port: int, + tls_key: Optional[str], + tls_cert: Optional[str], + client_ca: Optional[str], +) -> bool: + """Probe the http server + + The implementation of this utility is a bit tricky because mTLS makes this + quite challenging. For insecure or TLS servers, we expect a valid ready + response, but for mTLS servers, we may not have a valid key/cert pair that + the client can present to the server that is signed by the expected CA if + the trusted client CA does not match the one that signed the server's + key/cert pair. + + The workaround for this is to detect SSLError and consider that to be a + passing readiness check. If the server is ready enough to _reject_ bad SSL + requests, it's ready enough to server good ones! 
+ + Args: + port (int): The port that the HTTP server is serving on + tls_key (Optional[str]): Body or path to the TLS key file if TLS/mTLS + enabled + tls_cert (Optional[str]): Body or path to the TLS cert file if TLS/mTLS + enabled + client_ca (Optional[str]): The client ca cert that the server is using + for mutual client auth + + Returns: + ready (bool): True if the http server is ready to take requests, False + otherwise + """ + # NOTE: Local imports for optional dependency + with alog.ContextTimer(log.debug2, "Done with local grpc imports: "): + + # Third Party + import requests # pylint: disable=import-outside-toplevel + + # Requests requires that the TLS information be in files + with _tls_files(tls_key, tls_cert) as tls_files: + key_file, cert_file = tls_files + if key_file and cert_file: + protocol = "https" + kwargs = {"verify": False} + if client_ca: + log.debug("Probing mTLS HTTP Server") + kwargs["cert"] = (key_file, cert_file) + else: + log.debug("Probing TLS HTTP Server") + else: + log.debug("Probing INSECURE HTTP Server") + protocol = "http" + kwargs = {} + + try: + # Suppress insecure connection warnings since we disable server + # verification. This is ok since this probe will be run against + # localhost in a pod where the server is _known_ to be authentic. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", module="urllib3") + session = requests.Session() + retries = requests.adapters.Retry(total=0) + session.mount( + f"{protocol}://", requests.adapters.HTTPAdapter(max_retries=retries) + ) + # NOTE: Not using the constant to avoid big imports + resp = session.get( + f"{protocol}://localhost:{port}/health", + timeout=get_config().runtime.http.probe_timeout, + **kwargs, + ) + resp.raise_for_status() + return True + except requests.exceptions.SSLError as err: + log.debug("Got SSLError indicating a healthy SSL server! %s", err) + log.debug2(err, exc_info=True) + return True + except Exception as err: # pylint: disable=broad-exception-caught + log.debug2("Caught unexpected error: %s", err, exc_info=True) + return False + + +def _grpc_readiness_probe( + port: int, + tls_key: Optional[str], + tls_cert: Optional[str], + client_ca: Optional[str], +) -> bool: + """Probe the grpc server + + Since the gRPC server trusts its own cert for client verification, we can + make a valid readiness probe against the running server regardless of (m)TLS + config. 
+ + Args: + port (int): The port that the gRPC server is serving on + tls_key (Optional[str]): Body or path to the TLS key file if TLS/mTLS + enabled + tls_cert (Optional[str]): Body or path to the TLS cert file if TLS/mTLS + enabled + client_ca (Optional[str]): The client ca cert that the server is using + for mutual client auth + + Returns: + ready (bool): True if the grpc server is ready to take requests, False + otherwise + """ + # NOTE: Local imports for optional dependency + with alog.ContextTimer(log.debug2, "Done with local grpc imports: "): + + # Third Party + from grpc_health.v1 import ( # pylint: disable=import-outside-toplevel + health_pb2, + health_pb2_grpc, + ) + import grpc # pylint: disable=import-outside-toplevel + + # Server hostname to use unless using socket mode + hostname = f"localhost:{port}" + socket_file = get_config().runtime.grpc.unix_socket_path + + # If available, use a unix socket + if socket_file and os.path.exists(os.path.dirname(socket_file)): + socket_address = f"unix://{socket_file}" + log.debug("Probing gRPC server over unix socket: %s", socket_file) + channel = grpc.insecure_channel(socket_address) + + elif tls_key and tls_cert: + tls_server_key = bytes(_load_secret(tls_key), "utf-8") + tls_server_cert = bytes(_load_secret(tls_cert), "utf-8") + if client_ca: + log.debug("Probing mTLS gRPC server") + credentials = grpc.ssl_channel_credentials( + root_certificates=tls_server_cert, + private_key=tls_server_key, + certificate_chain=tls_server_cert, + ) + else: + log.debug("Probing TLS gRPC server") + credentials = grpc.ssl_channel_credentials( + root_certificates=tls_server_cert, + ) + + # NOTE: If the server's certificate does not have 'localhost' in it, + # this will cause certificate validation errors and fail. The original + # workaround for this was to parse the cert's SANs and use hostname + # overrides, but that requires a full cryptographic PEM parser which + # is a security-sensitive dependency to pull that we want to avoid. + # Instead, the workaround is to use the unix socket server option + # above. 
+ channel = grpc.secure_channel(hostname, credentials=credentials) + else: + log.debug("Probing INSECURE gRPC server") + channel = grpc.insecure_channel(hostname) + + client = health_pb2_grpc.HealthStub(channel) + try: + client.Check( + health_pb2.HealthCheckRequest(), + timeout=get_config().runtime.grpc.probe_timeout, + ) + return True + except Exception as err: # pylint: disable=broad-exception-caught + log.debug2("Caught unexpected error: %s", err, exc_info=True) + return False + + +@contextmanager +def _tls_files( + tls_key: Optional[str], + tls_cert: Optional[str], +) -> Tuple[Optional[str], Optional[str]]: + """Get files for the TLS key/cert if given""" + if not tls_key or not tls_cert: + yield None, None + return + valid_file_vals = [ + ((os.path.exists(fname) and fname) or None) for fname in [tls_key, tls_cert] + ] + if all(valid_file_vals): + yield tls_key, tls_cert + return + with tempfile.TemporaryDirectory() as workdir: + key_file, cert_file = valid_file_vals + if not key_file: + key_file = os.path.join(workdir, "tls.key") + with open(key_file, "w", encoding="utf-8") as handle: + handle.write(tls_key) + if not cert_file: + cert_file = os.path.join(workdir, "tls.cert") + with open(cert_file, "w", encoding="utf-8") as handle: + handle.write(tls_cert) + yield key_file, cert_file + return + + +def _load_secret(secret: str) -> str: + """NOTE: Copied from grpc_server to avoid costly imports""" + if os.path.exists(secret): + with open(secret, "r", encoding="utf-8") as secret_file: + return secret_file.read() + return secret + + +## Main ######################################################################## +def main(): + caikit_config = get_config() + alog.configure( + default_level=caikit_config.log.level, + filters=caikit_config.log.filters, + thread_id=caikit_config.log.thread_id, + formatter=caikit_config.log.formatter, + ) + + # Pull the probe type from the command line, defaulting to readiness + probe_type_map = { + "readiness": readiness_probe, + "liveness": liveness_probe, + } + probe_type = "readiness" + probe_args = [] + if len(sys.argv) > 1: + probe_type = sys.argv[1] + if len(sys.argv) > 2: + probe_args = sys.argv[2:] + log.debug("Probe type: %s", probe_type) + log.debug("Probe args: %s", probe_args) + probe_fn = probe_type_map.get(probe_type.lower()) + assert probe_fn is not None, f"Invalid probe type: {probe_type}" + + if not probe_fn(*probe_args): + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/docs/adrs/021-http-server.md b/docs/adrs/021-http-server.md index 8a83d766a..f52def97b 100644 --- a/docs/adrs/021-http-server.md +++ b/docs/adrs/021-http-server.md @@ -4,7 +4,7 @@ Up until now, caikit has primarily managed its server runtime using gRPC. This d As caikit moves to being a general-purpose framework for managing production-grade model definitions and operations, supporting a native REST server has become increasingly critical to enable the "15-minutes to value" experience. Prior to this ADR, REST has been supported using grpc-gateway-wrapper which has not been sufficiently flexible for all usecases (including streaming). -Finally, the most popular open model hosting framework at the time of this ADR is huggingface (https://huggingface.co/). HF offers a [REST API](huggingface.co/docs/api-inference/detailed_parameters) for its task-specific inference endpoints. The current consumers of caikit are strategically aligned with HF for open-model collaboration, and therefore caikit needs to provide aligned APIs with the HF task inference endpoints. 
+Finally, the most popular open model hosting framework at the time of this ADR is huggingface (https://huggingface.co/). HF offers a [REST API](https://huggingface.co/docs/api-inference/detailed_parameters) for its task-specific inference endpoints. The current consumers of caikit are strategically aligned with HF for open-model collaboration, and therefore caikit needs to provide aligned APIs with the HF task inference endpoints. ## Decision @@ -20,4 +20,4 @@ Accepted * Users will be able to select running `caikit.runtime` with an HTTP server, a gRPC server, or both to best match their application usecases * The `caikit` ecosystem will be able to align to Hugging Face task inference APIs * Authors of `caikit` data model objects and modules will need to work through compatibility issues with both HTTP and gRPC server interface deduction -* There will be a new set of python dependencies to support the HTTP stack that need to be managed \ No newline at end of file +* There will be a new set of python dependencies to support the HTTP stack that need to be managed diff --git a/docs/adrs/023-info-service-endpoint.md b/docs/adrs/023-info-service-endpoint.md new file mode 100644 index 000000000..37e84e856 --- /dev/null +++ b/docs/adrs/023-info-service-endpoint.md @@ -0,0 +1,27 @@ +# ADR N023: Info Service Endpoint + +Some users of `caikit` find it useful to be able to easily retrieve versioning information to help with debugging. This includes details on the version of caikit libraries as well as the runtime image. + +## Decision + +This proposal adds an info servicer that exposes an endpoint on both the gRPC and HTTP servers for retrieving version information. The info servicer will be generic so it can serve other information requests in the future, such as model management metadata. By default the endpoint returns a dictionary of version information for caikit libraries; it can additionally report other python packages and the runtime image if those are configured. + +For HTTP the endpoint will be something like `/info/version` and for gRPC it will be something like `GetRuntimeInfo()`. + +**Example Config** +```yaml +runtime: + version_info: + python_packages: + all: false + runtime_image: "" +``` + +## Status + +Accepted + +## Consequences + +* Users can hit an info endpoint that will provide caikit versioning details as well as versions for other python packages. +* New `config` section `runtime.version_info` that controls what version information is exposed by the endpoint. diff --git a/docs/adrs/024-remote-module-invocation.md b/docs/adrs/024-remote-module-invocation.md new file mode 100644 index 000000000..bbe544054 --- /dev/null +++ b/docs/adrs/024-remote-module-invocation.md @@ -0,0 +1,282 @@ +# ADR 024 Remote Module Invocation + +Currently, caikit only supports locally loaded modules and does not understand a remote runtime +server. Remote servers can allow for a better distribution of resources and more complex runtime +architectures. Remote modules could also form the basis of a ["thin client"](https://github.com/caikit/caikit/issues/255) +where a service can run remote invocations without installing library dependencies. +Library users have started creating custom initializers to solve these problems; however, +those implementations depend on specific assumptions about the caikit runtime that might change +from version to version. + +## Decision + +We propose implementing a system for remote model invocation directly in the Caikit core library.
+This system will handle discovering and describing models from a remote runtime. These remote modules +will look and function like local modules, except any task or train invocation is forwarded to the +remote. The core components of this system will be a new ModelFinder named RemoteModelFinder and a +new ModelInitializer called RemoteModelInitializer. + +The RemoteModelFinder will gather the remote server's connection and model information. In the first +iteration, the RemoteModelFinder will use the collected information to find a locally available +ModuleBase and construct a new RemoteModuleConfig. By design, the RemoteModuleConfig does not +contain any direct references to the imported Module and uses CaikitMethodSignatures to describe the +methods and tasks. This allows future implementations or other sources to construct RemoteModuleConfigs +without necessarily having to import the local ModuleBase. To adequately describe these methods, +some assumptions are made about the Caikit service generation, specifically around the dataclass +and rpc naming schemes. + +The RemoteModuleConfig is then passed to the RemoteModelInitializer, which will construct a new +RemoteModuleInstance with the same methods, signatures, and parameters as the source Module without +using `caikit.runtime` or references to the original. This module will be constructed with the same +`@module` and `@task.taskmethod` decorators to ensure the module acts precisely as a locally loaded +module. One possible issue for the future is that the RemoteModelInstance relies on +dataclasses created during service generation. + + + +### RemoteModelFinder + + +```yaml +model_management: + finders: + <finder name>: + type: REMOTE + config: + connection: + hostname: str + port: int + protocol: Optional[str]="grpc" + tls: + enabled: Optional[bool]=False + ca_file: Optional[str]=None + cert_file: Optional[str]=None + key_file: Optional[str]=None + options: Optional[Dict[str,str]]={} + discover_models: Optional[bool]=True + supported_models: Optional[Dict[str, str]]={} + <other finder name>: +``` + +The proposed configuration for the RemoteModelFinder is above. The only required field is the +generic `connection` dictionary that supports a secure channel, mutual TLS, and custom GRPC/HTTP +options. The `connection.hostname` setting contains the remote's hostname, while `connection.port` determines the +runtime port. The optional `connection.protocol` config is used to select which protocol to send +requests over, with the default being `grpc`. The `connection.tls` dictionary contains all information +related to TLS: `tls.enabled` controls whether the server is running SSL, `tls.ca_file` is the path to +the CA file that the remote's certificate is signed by, `tls.cert_file` is the path to the +mTLS client certificate to be sent with the request, and `tls.key_file` is the path to the file +containing the mTLS client key. The final connection config is `connection.options`, which defines a +list of options to pass to either the HTTP or GRPC request; for an example of options, take a look +at the [GRPC Channel options](https://grpc.github.io/grpc/core/group__grpc__arg__keys.html#details) + + +Two additional optional fields help control what models this remote supports. The +`discover_models` setting is a boolean that controls if the finder should query the remote runtime +to dynamically discover what models are loaded and their corresponding `module_id`s.
The +`supported_models` config is a dictionary that contains a static mapping of model_paths to module_ids +that the remote supports. The `supported_models` setting is required to add support for remotes that +don't have a reflection api or ones that lazily load their models (like ModelMesh). + + +To help illustrate the above config, we included some pseudo python code to illustrate what happens +during model finding: +```python +def find_model(model_path: str)->RemoteModuleConfig: + # Check if model_path is in static mapping + if model_path in config.supported_models: + local_module = module_registry().get(config.supported_models[model_path]) + + # Check if model can be discovered dynamically + elif discover_models: + remote_model_mapping = gather_remote_model_map(config.connection) + local_module = remote_model_mapping.get(model_path) + + if not local_module: + raise CaikitCoreException("Model not found") + + # Construct config for use by the RemoteModelInitializer. This function is + # described down below in the #RemoteModuleConfig section + return generate_config_for_module(local_module, config.connection, model_path) +``` + +## RemoteModelInitializer + +```yaml +model_management: + initializers: + : + type: REMOTE +``` + +The proposed configuration for the RemoteModelInitializer is above. The initializer does not take in +any global configuration settings, as the remote information will be passed in via the +RemoteModelFinder. If a system is expected to have both local and remote models, consider using a +MultiModelInitializer to handle both use cases. + +To help illustrate how the RemoteModelInitializer would initialize a Module, we provided a snippet +of pseudo Python code: +```python3 +def init(model_config: RemoteModuleConfig)->ModuleBase: + # Construct empty RemoteModule Instance + @module( + id=model_config.id, + name=model_config.name, + version=model_config.name, + task=[task_tuple[0] for task_tuple in model_config.task_methods] + ) + class _RemoteModelInstance(RemoteModelBase): + pass + + # Add all task methods to the RemoteModel class + for task, inference_methods in model_config.task_methods: + for method in inference_methods: + infer_func = partial(_RemoteModelInstance.remote_method_request, method=method) + task_wrapped_func = task.taskmethod(infer_func) + setattr(_RemoteModelInstance, method.signature.name, task_wrapped_func) + + # Add train method to class if one exists + if model_config.train_method: + train_func = partial(_RemoteModelInstance.remote_method_request, method=model_config.train_method) + setattr(_RemoteModelInstance, model_config.train_method.signature.name, train_func) + + # Return Model Instance + return _RemoteModelInstance(model_config.connection, model_config.model_path) + + +class RemoteModelBase(ModuleBase): + def __init__(self, connection: Dict[str, Any], model_path: str): + ... 
+ + def remote_method_request(self, method: RemoteMethodRpc, *args, **kwargs): + # Run the remote invocation using the information defined in the RemoteMethodRpc + if self.connection.protocol == "grpc": + ... # grpc invocation elided + elif self.connection.protocol == "http": + ... # http invocation elided + +``` + +## RemoteModuleConfig + +```python +class RemoteModuleConfig(ModuleConfig): + # Remote runtime information copied from the RemoteModelFinder config + connection: Dict[str, Any] + + # Method information + # use list and tuples instead of a dictionary to avoid aconfig.Config error + task_methods: List[Tuple[type[TaskBase], List[RemoteMethodRpc]]] + train_method: RemoteMethodRpc + + # Source Module Information + module_id: str + module_name: str + model_path: str + +@dataclass +class RemoteMethodRpc: + # full signature for this RPC + signature: CaikitMethodSignature + + # Request and response objects for this RPC + request_dm_name: str + response_dm_name: str + + # Either the function name of the GRPC Servicer or HTTP endpoint + rpc_name: str + + # Only used for infer RPC types + input_streaming: bool + output_streaming: bool +``` + +The RemoteModuleConfig is a custom subclass of ModuleConfig, which contains a description of a Module's +tasks, inference and train methods, and version information, as well as the connection information of +the remote runtime. The combination of the two allows the RemoteModelInitializer to construct a new +RemoteModule without having to import `caikit.runtime` or the source model. To help simplify the +config definition and access, we decided to create a helper dataclass, `RemoteMethodRpc`, which contains +information about a specific method and includes things like the CaikitMethodSignature, +request & response DataModel names, and the remote RPC name. The `RemoteMethodRpc` dataclass contains +all the runtime-specific assumptions and is the main point where overlap with the runtime's naming +schemes happens. The `RemoteModelFinder`, which constructs `RemoteMethodRpc`s, should not import +`caikit.runtime`; however, it will re-use a lot of code around the different naming schemes. + +We created the following pseudocode to help illustrate how the `RemoteModelFinder` constructs a +`RemoteModuleConfig`. (Note this is the function used in the pseudocode for the [RemoteModelFinder](#remotemodelfinder)) + +```python +def generate_config_for_module(module: ModuleBase, connection_info: Dict[str, Any], model_path: str)->RemoteModuleConfig: + # Gather a description of all tasks and their associated methods + task_methods = [] + for task_class in module.tasks: + task_functions = [] + for input, output, signature in module.get_inference_signatures(task_class): + # Construct request_dm_name, task_request_name, and rpc_name.
This makes assumptions + about caikit.runtime service generation + request_dm_name ~= "TaskRequest" + task_request_name ~= "TaskPredict" + rpc_name = "/api/v1/TaskPredict" + + task_functions.append( + RemoteMethodRpc( + signature=signature, + request_dm_name=request_dm_name, + response_dm_name=signature.return_type.__name__, + rpc_name=rpc_name, + input_streaming=input, + output_streaming=output, + ) + ) + + task_methods.append((task_class, task_functions)) + + # Gather description of the Train functions + train_method = None + if module.TRAIN_SIGNATURE: + request_dm_name ~= "TrainRequest" + rpc_name ~= "/api/v1/TaskTrain" + + train_method = RemoteMethodRpc( + signature=module.TRAIN_SIGNATURE, + request_dm_name=request_dm_name, + response_dm_name=module.TRAIN_SIGNATURE.return_type.__name__, + rpc_name=rpc_name, + ) + + # Construct the remote config + return RemoteModuleConfig( + { + # Connection info + "connection":connection_info, + # Method info + "task_methods":task_methods, + "train_method":train_method, + # Source Module Information + "model_path": model_path, + "module_id": module.MODULE_ID, + "module_name": module.MODULE_NAME, + } + ) + +``` + +### Diagram + +This is an updated block diagram of the various model loading components and their relationships. + + +## Status + +Accepted + + +## Consequences + +- Library users will be able to configure remote runtime servers. +- Multiple caikit runtimes can work together to serve a large set of models. +- When updating or changing the service generation, the `RemoteModelFinder` will also have to be changed. diff --git a/examples/sample_lib/README.md b/examples/sample_lib/README.md index 6af92ae34..6c8d0acfb 100644 --- a/examples/sample_lib/README.md +++ b/examples/sample_lib/README.md @@ -1,26 +1,164 @@ +**Table of contents** +- [Interacting with the Sample lib](#interacting-with-the-sample-lib) + - [Build and start a runtime server with sample\_lib](#build-and-start-a-runtime-server-with-sample_lib) + - [Interact using the python client](#interact-using-the-python-client) + - [Interact using terminal](#interact-using-terminal) + - [To train a model](#to-train-a-model) + - [Using gRPC](#using-grpc) + - [Using HTTP](#using-http) + - [To check on training status for a training](#to-check-on-training-status-for-a-training) + - [Using gRPC](#using-grpc-1) + - [Using HTTP](#using-http-1) + - [To call inference on a model with model Id](#to-call-inference-on-a-model-with-model-id) + - [To use the gRPC Server for inference](#to-use-the-grpc-server-for-inference) + - [To use the REST Server for inference](#to-use-the-rest-server-for-inference) + - [Interact using a combination of pb2s and DataModels](#interact-using-a-combination-of-pb2s-and-datamodels) + # Interacting with the Sample lib -Run `python3 -m examples.sample_lib.start_runtime_with_sample_lib` in one terminal. +This document describes how to quickly stand up a runtime server built with the `sample_lib` library, train a model over gRPC, and then send an inference call for that trained model to the server over either HTTP or gRPC. + +## Build and start a runtime server with sample_lib + +Run the `start_runtime_with_sample_lib` python script: + +```shell +python3 -m examples.sample_lib.start_runtime_with_sample_lib +``` + +This will set up a config with both `grpc` and `http` servers enabled for inference and training. The script then starts the caikit runtime server.
While the server is running, you can see the generated proto files in a directory called `protos`. (They will be auto-deleted once you kill the server) + +We generate 3 services total: +- A `train` service. The proto for this service is `protos/samplelibtrainingservice.proto` +- An `inference` service. The proto for this service is `protos/samplelibservice.proto` +- A `training management` service. The proto for this service is `protos/trainingmanagement.proto` + +You can now leave the server running and open a new terminal to proceed with next steps to train a model, check its training status and send an inference request to your model. + +(To kill the server, press Ctrl + C. This will remove the `protos` directory to clean up.) + +## Interact using the python client + +You can run the python client using: + +```shell +python3 -m examples.sample_lib.client +``` + +The python client sends in requests to all 3 services that were mentioned above, printing the result from each request. + +## Interact using terminal + +You can also use `grpcurl` (for gRPC requests) or `curl` (for http requests) to send in commands one-by-one to all the 3 services that were mentioned above. + +Note: `http` does not currently support `training management` APIs. +### To train a model + +#### Using gRPC + +In order to train a model via gRPC, we will use `grpcurl` and point the import-path to `protos` dir, then call one of the Train rpc's available in the `SampleLibTrainingService` (see `protos/samplelibtrainingservice.proto` file generated above for all Train rpcs): + +```shell +grpcurl -plaintext -import-path protos/ -proto samplelibtrainingservice.proto -d '{"model_name": "my_model", "parameters": {"training_data": {"file": {"filename": "protos/sample.json"}}}}' localhost:8085 caikit_sample_lib.SampleLibTrainingService/SampleTaskSampleModuleTrain +``` + +You should receive a response similar to the below: + +```shell +{ + "trainingId": "wTHxlsu:5bdb5949-4efa-4512-bbac-709cbf37c00e", + "modelName": "my_model" +} +``` + +Copy the `trainingId` to use in next step. + +#### Using HTTP + +Docs coming soon... + +### To check on training status for a training + +#### Using gRPC + +With a `trainingId`, you can get a training status via gRPC. Replace the command below with your `trainingId`. + +```shell +grpcurl -plaintext -import-path protos/ -proto trainingmanagement.proto -d '{"training_id": ""}' localhost:8085 caikit.runtime.training.TrainingManagement/GetTrainingStatus +``` + +You should get a response like this: + +```shell +{ + "trainingId": "wTHxlsu:5bdb5949-4efa-4512-bbac-709cbf37c00e", + "state": "COMPLETED", + "submissionTimestamp": "2023-08-30T22:19:13.739694Z", + "completionTimestamp": "2023-08-30T22:19:13.744542Z" +} +``` +Once your training is completed, you can proceed to call inference on the model. + +#### Using HTTP + +`http` currently doesn't support training status APIs. Coming soon... + +### To call inference on a model with model Id + +You are now ready to call inference via either gRPC or REST. 
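If you prefer Python to the command-line tools shown below, the same inference call can be made with a small script (a sketch: it assumes the server is running as above and that you have compiled the pb2 modules from `protos/` as described in the last section; the module and message names here are assumptions based on the generated proto files, so check your own generated names):

```python
# Third Party
import grpc

# Compiled from protos/samplelibservice.proto (see the pb2 section below);
# the module and message names are assumptions -- check your generated files
import samplelibservice_pb2 as pb2
import samplelibservice_pb2_grpc as pb2_grpc

channel = grpc.insecure_channel("localhost:8085")
stub = pb2_grpc.SampleLibServiceStub(channel)

# The target model is selected with the mm-model-id metadata header
response = stub.SampleTaskPredict(
    pb2.SampleTaskRequest(sample_input=pb2.SampleInputType(name="world")),
    metadata=[("mm-model-id", "my_model")],
)
print(response.greeting)  # "Hello world"
```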
+
+#### To use the gRPC Server for inference
+
+You can also use the gRPC Server to call inference on this model by running:
+
+```shell
+grpcurl -plaintext -import-path protos/ -proto samplelibservice.proto -d '{"sample_input": {"name": "world"}}' -H 'mm-model-id: my_model' localhost:8085 caikit_sample_lib.SampleLibService/SampleTaskPredict
+```
+
+You should receive a successful response back with a response body:
+
+```shell
+{
+  "greeting": "Hello world"
+}
+```
-The following sections show how to interact with the server using either REST or gRPC.
+#### To use the REST Server for inference
-## REST
+
+- In a browser of your choice, visit `http://localhost:8080/docs/`. All the available inference RPCs are listed there. Expand the task that matches the model you trained. In this example, we are using `api/v1/task/sample`.
+- Click "Try It Out".
+- Fill in the model_id "my_model" used in the train-a-model step, change the request body to your liking, then click "Execute". For example:
-Head over to `localhost:8080/docs`. First train a module then try inferencing on it. (Training management API is not yet supported on REST)
+
+```shell
+curl -X 'POST' \
+  'http://localhost:8080/api/v1/task/sample' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{"inputs": {"name": "world"}, "model_id": "my_model"}'
+```
-## GRPC
+
+You should receive a 200 response back with a response body:
+
+```shell
+{
+  "greeting": "Hello world"
+}
+```
-The server should automatically create a `protos` dir. CD into that, `cd protos`, then try the following commands.
+
+## Interact using a combination of pb2s and DataModels
-### Training
+
+Install `protoc`:
-```grpcurl -plaintext -proto samplelibtrainingservice.proto -d '{"model_name": "my_model", "training_data": {"file": {"filename": "protos/sample.json"}}}' localhost:8085 caikit.runtime.SampleLib.SampleLibTrainingService/SampleTaskSampleModuleTrain```
+
+```shell
+pip3 install grpcio-tools
+```
-### Training status
+
+then generate the compiled `pb2` files,
-```grpcurl -plaintext -proto trainingmanagement.proto -d '{"training_id": "
diff --git a/pyproject.toml b/pyproject.toml
--- a/pyproject.toml
+++ b/pyproject.toml
@@ ... @@ dependencies = [
     =0.14.1,<0.16.0",
     "grpcio>=1.35.0,<2.0,!=1.55.0",
     "ijson>=3.1.4,<3.3.0",
+    "importlib-metadata>=6.8.0,<8.0.0",
     "munch>=2.5.0,<5.0",
     "numpy>=1.22.2,<2",
     "protobuf>=3.19.0,<5",
+    "psutil>=5,<6",
     "py-to-proto>=0.5.0,<0.6.0,!=0.2.1",
     "PyYAML>=6.0,<7.0",
     "semver>=2.13.0,<4.0",
@@ -29,6 +31,12 @@ dependencies = [
     "werkzeug>=2.3.7,<4.0.0"
 ]
 
+[project.scripts]
+
+caikit-runtime = "caikit.runtime.__main__:main"
+caikit-health-probe = "caikit_health_probe.__main__:main"
+caikit-render-interfaces = "caikit.runtime.dump_services:main"
+
 [project.optional-dependencies]
 
 ## Runtime Extra Sets ##
@@ -42,6 +50,7 @@ runtime-grpc = [
 
 runtime-http = [
     "fastapi[all]>=0.100,<1",
+    "requests>=2.28.2,<3",
     "sse-starlette>=1.6.1,<2",
 ]
 
@@ -49,33 +58,44 @@ interfaces-vision = [
     "pillow>=6.2.1,<11.0"
 ]
 
+interfaces-ts = [
+    "pandas>=1.4.3,<2",
+]
+
+interfaces-ts-pyspark = [
+    "caikit[interfaces-ts]",
+    "pyspark>=3.3,<3.6",
+    "pyarrow>=8.0.0,<15"
+]
+
 # NOTE: This is "all" from the user perspective, not the dev perspective
 all = [
-    "caikit[runtime-grpc, runtime-http, interfaces-vision]",
+    "caikit[runtime-grpc, runtime-http, interfaces-vision, interfaces-ts]",
 ]
 
 ## Dev Extra Sets ##
 
 dev-test = [
-    "pytest-asyncio>=0.21.0,<1",
+    # NOTE: pytest-asyncio>=0.22 breaks importing with an error about multiple
+    # imports of sample modules
+    "pytest-asyncio>=0.21.0,<0.22",
     "pytest-cov>=2.10.1,<5.0",
     "pytest-html>=3.1.1,<5.0",
     "pytest>=6.2.5,<8.0",
-    "requests>=2.28.2,<3",
"requests>=2.28.2,<3", "tls_test_tools>=0.1.1", "wheel>=0.38.4", - "caikit[interfaces-vision]", + "caikit[interfaces-vision, interfaces-ts-pyspark]", ] dev-docs = [ "sphinx>=4.0.2,<8.0", "sphinx-autoapi>=2.1.0", - "sphinx-rtd-theme>=1.2.1,<1.4.0", + "sphinx-rtd-theme>=1.2.1,<2.1.0", ] dev-fmt = [ + "ruff==0.1.11", "pre-commit>=3.0.4,<4.0", - "pylint>=2.16.2,<4.0", "pydeps>=1.12.12,<2", ] @@ -83,6 +103,14 @@ dev-build = [ "flit==3.9.0", ] +dev-proto3 = [ + "caikit[all-dev]", + "protobuf>=3.19.0,<3.20", + "grpcio>=1.35.0,<1.49", + "grpcio-health-checking>=1.35.0,<1.49", + "grpcio-reflection>=1.35.0,<1.49", +] + # NOTE: This is "all" from the user and dev perspective all-dev = [ "caikit[all, dev-test, dev-docs, dev-fmt, dev-build]" @@ -101,4 +129,70 @@ Source = "https://github.com/caikit/caikit" [tool.pytest.ini_options] markers = [ "examples: marks tests as e2e examples (deselect with '-m \"not examples\"')", + "slow: marks tests requiring pyspark be installed (deselect with '-m \"not slow\"')" +] +filterwarnings = [ + "ignore:distutils Version classes are deprecated.*:DeprecationWarning", + "ignore:np.find_common_type is deprecated.*:DeprecationWarning", + "ignore:Converting `np.character` to a dtype is deprecated.*:DeprecationWarning", +] + +[tool.ruff] +line-length = 100 +target-version = "py38" +exclude = ["caikit/runtime/protobufs/*.py"] + + +[tool.ruff.lint] +select = [ "E", "F", "UP", "B", "SIM", "I"] +ignore = [ + "UP032", # f-string + "UP034", # extraneous-parentheses + # "UP035", # deprecated-import + + ## original errors fromt pylint + "F403", # unable to detect undefined names + "I001", # import block unsorted/unformatted + "E402", # module level import not at top of file + # "B028", # warnings: no explicit stacklevel keyword argument found + # "I0001", # raw-checker-failed + # "I0010", # bad-inline-option + # "I0011", # locally-disabled + # "I0013", # file-ignored + # "I0020", # suppressed-message + # "I0021", # useless-suppression + # "I0022", # deprecated-pragma + + ## added messages in caikit + # "I0023", # use-symbolic-message-instead + # "C0103", # invalid-name + # "C0115", # missing-class-docstring + # "C0114", # missing-module-docstring + # "C0116", # missing-function-docstring + # "C0209", # consider-using-f-string + # "R1710", # inconsistent-return-statements + # "E1101", # no-member + # "R0913", # too-many-arguments + # "R0914", # too-many-locals + # "R0912", # too-many-branches + # "R0915", # too-many-statements + # "R0401", # cyclic-import + # "R0903", # too-few-public-methods + # "W0212", # protected-access + # "W0511", # fixme + # "W1202", # logging-format-interpolation + # "E1205", # logging-too-many-args + # "W0201", # attribute-defined-outside-init + # "W0223", # abstract-method + # "W0104", # pointless-statement + # "C0411", # wrong-import-order +] + +[tool.ruff.per-file-ignores] +"__init__.py" = [ + "F401", # imported but unused + "F403" # unable to detect undefined names +] +"caikit/runtime/service_generation/protoable.py" = [ + "SIM114", # Combine `if` branches using logical `or` operator # TODO: simplify this logic ] diff --git a/tests/conftest.py b/tests/conftest.py index d9cdf751f..3e5716766 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,14 +7,16 @@ from typing import Callable, List, Union from unittest.mock import patch import copy -import json +import importlib import os +import platform import sys import tempfile import uuid # Third Party import pytest +import semver # First Party import alog @@ -38,6 +40,10 @@ "fixtures", ) +# Some tests 
need to be skipped if using protobuf 3.X and arm +PROTOBUF_VERSION = semver.parse(importlib.metadata.version("protobuf"))["major"] +ARM_ARCH = "arm" in platform.machine() + # Make sample_lib available for import sys.path.append(FIXTURES_DIR) # Local @@ -92,6 +98,11 @@ def box_model_path() -> str: return os.path.join(FIXTURES_DIR, "models", "box") +@pytest.fixture +def primitive_model_path() -> str: + return os.path.join(FIXTURES_DIR, "models", "primitive") + + @pytest.fixture def streaming_model_path() -> str: return os.path.join(FIXTURES_DIR, "dummy_streaming_module") diff --git a/tests/core/data_model/streams/test_data_stream.py b/tests/core/data_model/streams/test_data_stream.py index d3756e7d3..1eed24a35 100644 --- a/tests/core/data_model/streams/test_data_stream.py +++ b/tests/core/data_model/streams/test_data_stream.py @@ -167,6 +167,26 @@ def test_data_stream_from_jsonl_is_pickleable(tmp_path): post_pickle_vals = list(pickled_stream) assert pre_pickle_vals == post_pickle_vals + # Interesting: Technically this is a stream of length 1 where the one element is [1,2,3,4,5,6] + validate_data_stream(pickled_stream, 1, list) + + +def test_data_stream_from_json_is_pickleable(tmp_path): + tmpdir = str(tmp_path) + + data = [1, 2, 3, 4, 5, 6] + filepath = os.path.join(tmpdir, "foo.json") + with open(filepath, "w") as f: + json.dump(data, f) + + stream = DataStream.from_json_array(filepath) + + pre_pickle_vals = list(stream) + pickled_stream = pickle.loads(pickle.dumps(stream)) + post_pickle_vals = list(pickled_stream) + + assert pre_pickle_vals == post_pickle_vals + validate_data_stream(pickled_stream, 6, int) def test_bad_json_stream(tmp_path): diff --git a/tests/core/data_model/test_json_dict.py b/tests/core/data_model/test_json_dict.py index 5a82b6490..4ac49da79 100644 --- a/tests/core/data_model/test_json_dict.py +++ b/tests/core/data_model/test_json_dict.py @@ -49,13 +49,27 @@ def test_dict_to_struct_to_dict(): ) assert struct.fields["null_val"].WhichOneof("kind") == "null_value" assert struct.fields["null_val"].null_value == struct_pb2.NullValue.NULL_VALUE - assert isinstance(struct.fields["list_val"].list_value, struct_pb2.ListValue) assert len(struct.fields["list_val"].list_value.values) == len(raw_dict["list_val"]) - assert isinstance(struct.fields["dict_val"].struct_value, struct_pb2.Struct) assert len(struct.fields["dict_val"].struct_value.fields) == len( raw_dict["dict_val"] ) + # NOTE: We cannot do the following isinstance() tests because they fail with proto3 + # because of the temporary descriptor pool copying for the out-of-the-box struct_pb2. + # `assert isinstance(struct.fields["dict_val"].struct_value, struct_pb2.Struct)` + # `assert isinstance(struct.fields["list_val"].list_value, struct_pb2.ListValue)` + # Instead we will just check the expected class names (ignoring the class modules). 
+ assert ( + struct.fields["list_val"].list_value.__class__.__name__ + == struct_pb2.ListValue.__name__ + == "ListValue" + ) + assert ( + struct.fields["dict_val"].struct_value.__class__.__name__ + == struct_pb2.Struct.__name__ + == "Struct" + ) + def test_dict_to_struct_invalid_value(): """Make sure that a ValueError is raised if a bad type is encountered""" diff --git a/tests/core/model_management/test_multi_model_finder.py b/tests/core/model_management/test_multi_model_finder.py index 0d3ddf31f..2d9cfeffc 100644 --- a/tests/core/model_management/test_multi_model_finder.py +++ b/tests/core/model_management/test_multi_model_finder.py @@ -28,6 +28,7 @@ from caikit.core.model_management.multi_model_finder import MultiModelFinder from tests.conftest import temp_config from tests.core.helpers import TestFinder +import caikit ## Helpers ##################################################################### @@ -113,11 +114,17 @@ def test_multi_model_finder_first_not_found(test_finder_config, good_model_path) assert finder.find_model(good_model_path) -def test_multi_model_finder_not_found(): +def test_multi_model_finder_not_found(reset_globals): """Make sure that a simple proxy to local works""" with temp_config_finder() as finder: assert isinstance(finder, MultiModelFinder) assert finder.find_model("not/a/valid/path") is None + with pytest.raises(ValueError) as e: + caikit.core.load("bad/path/to/model") + assert ( + e.value.args[0] + == "value check failed: Unable to find a ModuleConfig for bad/path/to/model" + ) @pytest.mark.parametrize( diff --git a/tests/core/model_management/test_multi_model_initializer.py b/tests/core/model_management/test_multi_model_initializer.py new file mode 100644 index 000000000..0d83a270a --- /dev/null +++ b/tests/core/model_management/test_multi_model_initializer.py @@ -0,0 +1,101 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""
+Tests for the MultiModelInitializer
+"""
+# Standard
+from contextlib import contextmanager
+
+# Third Party
+import pytest
+
+# First Party
+import aconfig
+
+# Local
+from caikit.core.model_management.factories import model_initializer_factory
+from caikit.core.model_management.local_model_finder import LocalModelFinder
+from caikit.core.model_management.model_initializer_base import ModelInitializerBase
+from caikit.core.modules import ModuleConfig
+from tests.conftest import temp_config
+
+## Helpers #####################################################################
+
+# Add bad initializer to model factory
+class BadModelInitializer(ModelInitializerBase):
+    name = "BAD"
+
+    def __init__(self, config: aconfig.Config, instance_name: str):
+        """A FactoryConstructible object must be constructed with a config
+        object that it uses to pull in all configuration
+        """
+        pass
+
+    def init(
+        self,
+        model_config,
+        **kwargs,
+    ):
+        raise ValueError("Bad Model Initializer")
+
+
+model_initializer_factory.register(BadModelInitializer)
+
+
+@contextmanager
+def construct_mm_initializer(multi_model_config, config_override={}):
+    config_override = config_override or {
+        "model_management": {
+            "initializers": {
+                "local": {
+                    "type": "LOCAL",
+                },
+                "bad": {"type": "BAD"},
+            }
+        }
+    }
+
+    with temp_config(config_override, "merge"):
+        model_config = {
+            "type": "MULTI",
+            "config": multi_model_config,
+        }
+        yield model_initializer_factory.construct(model_config, "instance_name")
+
+
+## Tests #######################################################################
+
+
+@pytest.mark.parametrize(
+    ["initializers", "load_successful"],
+    [[["local"], True], [["bad", "local"], True], [["bad"], False]],
+)
+def test_multi_model_initializer(good_model_path, initializers, load_successful):
+    finder = LocalModelFinder(aconfig.Config({}), "local")
+    config = finder.find_model(good_model_path)
+    with construct_mm_initializer(
+        {"initializer_priority": initializers}
+    ) as initializer:
+        if load_successful:
+            assert initializer.init(config)
+        else:
+            assert not initializer.init(config)
+
+
+def test_multi_model_initializer_bad_config():
+    config = ModuleConfig({"module_id": "bad"})
+    with construct_mm_initializer(
+        {"initializer_priority": ["bad", "local"]}
+    ) as initializer:
+        assert not initializer.init(config)
diff --git a/tests/core/modules/test_module.py b/tests/core/modules/test_module.py
index eaab139f7..0ecaba7c2 100644
--- a/tests/core/modules/test_module.py
+++ b/tests/core/modules/test_module.py
@@ -156,6 +156,61 @@ def test_save_module(model_path):
     assert module_saver.config.get("module_paths") == {"dummy": "./dummy"}
 
 
+def test_save_please_dont_destroy(model_path):
+    dummy_model = caikit.core.load(model_path)
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        # exist_ok=False raises FileExistsError
+        with pytest.raises(FileExistsError):
+            with ModuleSaver(
+                dummy_model,
+                model_path=tempdir,
+                exist_ok=False,
+            ):
+                pass
+
+        # Existing dir should definitely not be removed
+        assert os.path.exists(tempdir)
+
+        # exist_ok=True does not raise error
+        with ModuleSaver(
+            dummy_model,
+            model_path=tempdir,
+            exist_ok=True,
+        ):
+            pass
+        assert os.path.exists(tempdir)
+
+        # exist_ok=True and exception thrown -> please don't destroy!
+        with pytest.raises(ValueError):
+            with ModuleSaver(
+                dummy_model,
+                model_path=tempdir,
+                exist_ok=True,
+            ):
+                raise ValueError  # a test exception
+        assert os.path.exists(tempdir)
+
+
+def test_save_okay_to_destroy(model_path):
+    dummy_model = caikit.core.load(model_path)
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        test_subdir = os.path.join(tempdir, "subdir")
+
+        # exist_ok=False and exception thrown -> okay to destroy
+        with pytest.raises(ValueError):
+            with ModuleSaver(
+                dummy_model,
+                model_path=test_subdir,
+                exist_ok=False,
+            ):
+                assert os.path.exists(test_subdir)
+                raise ValueError  # a test exception
+
+        assert not os.path.exists(test_subdir)
+
+
 def test_save_module_kwargs_get_piped_through(model_path):
     """Testing additional keywords passed to save_module are processed by the module"""
     dummy_model = caikit.core.load(model_path)
diff --git a/tests/core/test_imports.py b/tests/core/test_imports.py
index 468cb2d3f..a29113183 100644
--- a/tests/core/test_imports.py
+++ b/tests/core/test_imports.py
@@ -19,3 +19,7 @@ def test_caikit_core_has_DataValidationError():
 
 def test_caikit_core_has_error_handler():
     assert hasattr(caikit.core, "error_handler")
+
+def test_caikit_core_has_runtime_names():
+    import caikit.runtime
+    assert hasattr(caikit.runtime, "names")
diff --git a/tests/core/test_task.py b/tests/core/test_task.py
index 13082193b..27c7e201f 100644
--- a/tests/core/test_task.py
+++ b/tests/core/test_task.py
@@ -162,6 +162,9 @@ def run_third_task(self, foo: int) -> SampleOutputType:
     for t in [FirstTask, SecondTask, ThirdTask]:
         assert t in MultiTaskChildModule.tasks
 
+    # Make sure no tasks are double-counted
+    assert len(MultiTaskChildModule.tasks) == len(MultiTaskChildModule._TASK_CLASSES)
+
 
 def test_task_is_not_required_for_modules():
     @caikit.core.modules.module(id=str(uuid.uuid4()), name="Stuff", version="0.0.1")
diff --git a/tests/core/toolkit/concurrency/__init__.py b/tests/core/toolkit/concurrency/__init__.py
new file mode 100644
index 000000000..2068258bf
--- /dev/null
+++ b/tests/core/toolkit/concurrency/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/core/toolkit/test_destroyable_process.py b/tests/core/toolkit/concurrency/test_destroyable_process.py similarity index 78% rename from tests/core/toolkit/test_destroyable_process.py rename to tests/core/toolkit/concurrency/test_destroyable_process.py index 6c6976c4c..002e0ca5c 100644 --- a/tests/core/toolkit/test_destroyable_process.py +++ b/tests/core/toolkit/concurrency/test_destroyable_process.py @@ -20,13 +20,18 @@ import pytest # Local -from caikit.core.toolkit.destroyable_process import DestroyableProcess +from caikit.core.toolkit.concurrency.destroyable_process import DestroyableProcess +from tests.core.toolkit.concurrency.test_exception_pickler import ( + ReallyPoorlyBehavedException, + get_traceback, +) ## Helpers ##################################################################### EXPECTED_THROW = ValueError("test-any-error") EXPECTED_SUCCESS = "test-any-result" +UNPICKLABLE_ERROR = ReallyPoorlyBehavedException(message="This will not pickle") def infinite_wait(): @@ -42,6 +47,17 @@ def thrower(): raise EXPECTED_THROW +def nested_thrower(): + try: + raise EXPECTED_THROW + except Exception as e: + raise ValueError("some other error!") from e + + +def bad_thrower(): + raise UNPICKLABLE_ERROR + + def succeeder(): return EXPECTED_SUCCESS @@ -164,3 +180,30 @@ def test_default_event_is_set_on_completion(process_type): proc.start() proc.join() assert proc.completion_event.is_set() + + +def test_process_can_raise_unpicklable_exception(process_type): + proc = DestroyableProcess(process_type, bad_thrower) + proc.start() + proc.join() + + assert proc.threw + exception = proc.error + + assert str(UNPICKLABLE_ERROR) in str(exception) + + +def test_process_can_raise_nested_exception(process_type): + proc = DestroyableProcess(process_type, nested_thrower) + proc.start() + proc.join() + + assert proc.threw + + assert proc.error.__cause__ is not None + tb = get_traceback(proc.error) + assert len(tb) > 2 + assert ( + "The above exception was the direct cause of the following exception:" + in "".join(tb) + ) diff --git a/tests/core/toolkit/test_destroyable_thread.py b/tests/core/toolkit/concurrency/test_destroyable_thread.py similarity index 98% rename from tests/core/toolkit/test_destroyable_thread.py rename to tests/core/toolkit/concurrency/test_destroyable_thread.py index 804818869..0489c0b78 100644 --- a/tests/core/toolkit/test_destroyable_thread.py +++ b/tests/core/toolkit/concurrency/test_destroyable_thread.py @@ -26,7 +26,7 @@ import pytest # Local -from caikit.core.toolkit.destroyable_thread import ( +from caikit.core.toolkit.concurrency.destroyable_thread import ( DestroyableThread, ThreadDestroyedException, ) diff --git a/tests/core/toolkit/concurrency/test_exception_pickler.py b/tests/core/toolkit/concurrency/test_exception_pickler.py new file mode 100644 index 000000000..96ab74091 --- /dev/null +++ b/tests/core/toolkit/concurrency/test_exception_pickler.py @@ -0,0 +1,155 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Standard
+import pickle
+import traceback
+
+# Third Party
+import pytest
+
+# Local
+from caikit.core.toolkit.concurrency.pickling_exception import (
+    ExceptionPickler,
+    PickleFailureFallbackException,
+)
+
+
+#### Custom exception types for testing ##################################
+class WellBehavedException(Exception):
+    def __init__(self, *args):
+        self.message = args[0]
+        self.thing = args[1]
+
+
+class PoorlyBehavedException(Exception):
+    def __init__(self, message, private_arg):
+        self.message = message
+        self._private_arg = private_arg
+
+    def __str__(self):
+        return f"{self.message} ({self._private_arg})"
+
+
+class ReallyPoorlyBehavedException(Exception):
+    def __init__(self, message):
+        self._something_entirely_different = message
+
+    def __str__(self):
+        return self._something_entirely_different
+
+
+def raise_from_all(*exceptions):
+    """Given many exceptions, raise each one from the last one to make a chain of exceptions"""
+    last_exception = None
+    chain = []
+
+    for e in exceptions:
+        try:
+            if last_exception:
+                raise e from last_exception
+            raise e
+        except Exception as e:
+            last_exception = e
+            chain.insert(0, last_exception)
+
+    return chain
+
+
+def get_traceback(exc):
+    """Provides the list of formatted traceback lines for an exception"""
+    return traceback.format_exception(type(exc), exc, exc.__traceback__)
+
+
+def test_pickler_round_trips_an_exception_chain():
+    chain = raise_from_all(
+        ValueError("value must be > 0!"),
+        ReallyPoorlyBehavedException(message="input error!"),
+        PoorlyBehavedException(
+            message="Failed to validate input data", private_arg="code 7260"
+        ),
+        WellBehavedException("Training failed due to invalid input", "retryable=False"),
+    )
+
+    pickler = ExceptionPickler(chain[0])
+    pickler = pickle.loads(pickle.dumps(pickler))
+
+    exception = pickler.get()
+
+    assert isinstance(exception.__cause__.__cause__.__cause__, ValueError)
+
+    tb = "".join(get_traceback(exception))
+    for exc in chain:
+        assert str(exc) in tb
+
+
+def test_pickler_works_with_mix_of_arg_and_kwarg():
+    exception = PoorlyBehavedException(
+        "this is my arg", private_arg="and this is my kwarg"
+    )
+
+    unpickled = pickle.loads(pickle.dumps(ExceptionPickler(exception))).get()
+
+    assert isinstance(unpickled, PoorlyBehavedException)
+    assert str(exception) == str(unpickled)
diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py
index d35b20d59..8a18ae448 100644
--- a/tests/examples/test_examples.py
+++ b/tests/examples/test_examples.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Standard
-from os import getenv, path
+from os import path
 import asyncio
 import subprocess
 
@@ -66,3 +66,47 @@ def test_example_text_sentiment():
 
     # Client worked well, let's stop the server
     server.terminate()
+
+
+@pytest.mark.skip("Skipping until we figure out how to parallelize tests")
+@pytest.mark.examples
+def test_example_sample_lib():
+    # Example specific grpc port
+    grpc_port = 8085
+    with requirements("sample_lib") as (python_venv, example_dir):
+        # Start the server
+        with subprocess.Popen(
+            [python_venv, "start_runtime_with_sample_lib.py"],
+            cwd=example_dir,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        ) as server:
+            # Check if the gRPC port is open
+            # The gRPC server start-up time has some inherent variability
+            # 60s timeout should cover most situations, while keeping the
+            # test execution time reasonable
+            if not asyncio.run(waitForPort(grpc_port, 60)):
+                server.terminate()
+                pytest.fail(
+                    "Failed to connect to the gRPC server on port {} in 60s.".format(
+                        grpc_port
+                    )
+                )
+
+            # Server is running, start the client
+            # Use a timeout of 30s for inference. Capture outputs to report
+            # them in case of failure.
+            try:
+                subprocess.run(
+                    [python_venv, path.join(example_dir, "client.py")],
+                    check=True,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                    timeout=30,
+                )
+            except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
+                server.terminate()
+                pytest.fail("Client failed with output: {}".format(e))
+
+            # Client worked well, let's stop the server
+            server.terminate()
diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py
index 9e2360c23..5e439643f 100644
--- a/tests/fixtures/fixtures.py
+++ b/tests/fixtures/fixtures.py
@@ -55,6 +55,7 @@ class TestContext:
         def __init__(self, model_id):
             self.model_id = model_id
             self.callbacks = []
+            self.canceled = False
 
         def invocation_metadata(self):
             return [("mm-model-id", self.model_id)]
@@ -63,9 +64,11 @@ def add_callback(self, some_function, *args, **kwargs):
             self.callbacks.append(
                 {"func": some_function, "args": args, "kwargs": kwargs}
             )
-            return True
+            # Only return True if the call has not yet been canceled
+            return not self.canceled
 
         def cancel(self):
+            self.canceled = True
             [f["func"](*f["args"], **f["kwargs"]) for f in self.callbacks]
 
     return TestContext(model_id)
diff --git a/tests/fixtures/models/primitive/config.yml b/tests/fixtures/models/primitive/config.yml
new file mode 100644
index 000000000..9097aedcc
--- /dev/null
+++ b/tests/fixtures/models/primitive/config.yml
@@ -0,0 +1,14 @@
+module_class: sample_lib.modules.sample_task.SamplePrimitiveModule
+module_id: 00112233-0405-0607-0809-0a0b02dd0e0f
+created: "2023-03-28 16:34:58.720898"
+name: SampleModule
+sample_lib_version: 0.0.1
+saved: "2023-03-28 16:34:58.720929"
+tracking_id: e56dbd48-9231-432b-ad04-c02eeffdd158
+version: 0.0.1
+train:
+  training_params_json_dict:
+    foo:
+      bar: 123
+  training_params_dict:
+    layer_sizes: 1
diff --git a/tests/fixtures/sample_lib/data_model/sample.py b/tests/fixtures/sample_lib/data_model/sample.py
index c09a4a0a7..205f24c17 100644
--- a/tests/fixtures/sample_lib/data_model/sample.py
+++ b/tests/fixtures/sample_lib/data_model/sample.py
@@ -11,6 +11,11 @@
 # Local
 from caikit.core import DataObjectBase, TaskBase, dataobject, task
 from caikit.core.data_model import ProducerId
+from caikit.core.data_model.json_dict import JsonDict
+from caikit.core.exceptions.caikit_core_exception import (
+    CaikitCoreException,
+    CaikitCoreStatusCode,
+)
 from caikit.interfaces.common.data_model import File
 
 
@@ -29,6 +34,18 @@ class SampleListInputType(DataObjectBase):
     inputs: List[SampleInputType]
 
 
+# Test w/ just import and no dataobject
+@dataobject(package="caikit_data_model.sample_lib")
+class JsonDictInputType(DataObjectBase):
+    """A sample `JsonDict` input type for this library.
+
+    This exists because it impacts test_json_dict.py testing under proto3.
+    This class is not used, but it affects the descriptor pool behavior.
+ """ + + jd: JsonDict + + @dataobject(package="caikit_data_model.sample_lib") class FileInputType(DataObjectBase): """A simple type for tasks that deal with file data""" @@ -36,6 +53,13 @@ class FileInputType(DataObjectBase): file: File metadata: SampleInputType + def __post_init__(self): + if self.file.filename and ".exe" in self.file.filename: + raise CaikitCoreException( + status_code=CaikitCoreStatusCode.INVALID_ARGUMENT, + message="Executables are not a supported File type", + ) + @dataobject(package="caikit_data_model.sample_lib") class SampleOutputType(DataObjectBase): diff --git a/tests/fixtures/sample_lib/modules/sample_task/sample_implementation.py b/tests/fixtures/sample_lib/modules/sample_task/sample_implementation.py index 77868632f..e1b54646a 100644 --- a/tests/fixtures/sample_lib/modules/sample_task/sample_implementation.py +++ b/tests/fixtures/sample_lib/modules/sample_task/sample_implementation.py @@ -73,11 +73,12 @@ def run( @SampleTask.taskmethod(output_streaming=True) def run_stream_out( - self, sample_input: SampleInputType + self, sample_input: SampleInputType, err_stream: bool = False ) -> DataStream[SampleOutputType]: """ Args: sample_input (sample_lib.data_model.SampleInputType): the input + err_stream (bool): An optional parameter to error out the stream Returns: caikit.core.data_model.DataStream[sample_lib.data_model.SampleOutputType]: The output @@ -87,7 +88,15 @@ def run_stream_out( SampleOutputType(f"Hello {sample_input.name} stream") for x in range(self.stream_size) ] - stream = DataStream.from_iterable(list_) + # raise a value error when the stream is iterated, not before. + def raise_exception(): + raise ValueError("raising a ValueError") + + stream = ( + DataStream.from_iterable(list_) + if not err_stream + else DataStream.from_iterable([1]).map(lambda x: raise_exception()) + ) return stream @SampleTask.taskmethod(input_streaming=True, output_streaming=True) diff --git a/tests/interfaces/common/test_vectors.py b/tests/interfaces/common/test_vectors.py new file mode 100644 index 000000000..aa315508d --- /dev/null +++ b/tests/interfaces/common/test_vectors.py @@ -0,0 +1,152 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test for embedding vectors
+"""
+# Standard
+from collections import namedtuple
+
+# Third Party
+import numpy as np
+import pytest
+
+# Local
+from caikit.interfaces.common import data_model as dm
+
+## Setup #########################################################################
+
+DUMMY_VECTOR_SHAPE = (5,)
+RANDOM_SEED = 77
+np.random.seed(RANDOM_SEED)
+random_number_generator = np.random.default_rng()
+
+# To test the limits of our type-checking, this can replace our legit data objects
+TRICK_SEQUENCE = namedtuple("Trick", "values")
+
+
+@pytest.fixture
+def simple_array_of_floats():
+    return [1.1, 2.2]
+
+
+@pytest.fixture
+def simple_array_of_ints():
+    return ["foo", 1, 2, 3, 4]
+
+
+@pytest.fixture
+def random_numpy_vector1d_float32():
+    return random_number_generator.random(DUMMY_VECTOR_SHAPE, dtype=np.float32)
+
+
+@pytest.fixture
+def random_numpy_vector1d_float64():
+    return random_number_generator.random(DUMMY_VECTOR_SHAPE, dtype=np.float64)
+
+
+@pytest.fixture
+def random_python_vector1d_float(random_numpy_vector1d_float32):
+    return random_numpy_vector1d_float32.tolist()
+
+
+## Tests ########################################################################
+
+
+@pytest.mark.parametrize(
+    "sequence",
+    [
+        dm.PyFloatSequence(),
+        dm.NpFloat32Sequence(),
+        dm.NpFloat64Sequence(),
+        TRICK_SEQUENCE(values=None),
+    ],
+    ids=type,
+)
+def test_empty_sequences(sequence):
+    """No type check error with empty sequences"""
+    new_dm_from_init = dm.Vector1D(sequence)
+    assert isinstance(new_dm_from_init.data, type(sequence))
+    assert new_dm_from_init.data.values is None
+
+    # Test proto
+    proto_from_dm = new_dm_from_init.to_proto()
+    new_dm_from_proto = dm.Vector1D.from_proto(proto_from_dm)
+    assert isinstance(new_dm_from_proto, dm.Vector1D)
+    assert new_dm_from_proto.data.values is None
+
+    # Test json
+    json_from_dm = new_dm_from_init.to_json()
+    new_dm_from_json = dm.Vector1D.from_json(json_from_dm)
+    assert isinstance(new_dm_from_json, dm.Vector1D)
+    assert new_dm_from_json.data.values == []
+
+
+def test_vector1d_iterator_error():
+    """Cannot just shove in an iterator and expect it to work"""
+    with pytest.raises(ValueError):
+        dm.Vector1D(data=[1.1, 2.2, 3.3])
+
+
+def _assert_array_check(new_array, data_values, float_type):
+    for value in new_array.data.values:
+        assert isinstance(value, float_type)
+    np.testing.assert_array_equal(new_array.data.values, data_values)
+
+
+@pytest.mark.parametrize(
+    "float_seq_class, random_values, float_type",
+    [
+        (dm.PyFloatSequence, "random_python_vector1d_float", float),
+        (dm.NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32),
+        (dm.NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64),
+        (
+            TRICK_SEQUENCE,
+            "simple_array_of_floats",
+            float,
+        ),  # Sneaky but tests corner cases for now
+    ],
+)
+def test_vector1d_dm(float_seq_class, random_values, float_type, request):
+
+    # Test init
+    fixture_values = request.getfixturevalue(random_values)
+    dm_init = dm.Vector1D(data=float_seq_class(fixture_values))
+    _assert_array_check(dm_init, fixture_values, float_type)
+
+    # Test proto
+    dm_to_proto = dm_init.to_proto()
+    dm_from_proto = dm.Vector1D.from_proto(dm_to_proto)
+    _assert_array_check(dm_from_proto, fixture_values, float_type)
+
+    # Test json
+    dm_to_json = dm_init.to_json()
+    dm_from_json = dm.Vector1D.from_json(dm_to_json)
+    _assert_array_check(
+        dm_from_json, fixture_values, float
+    )  # NOTE: always float after json
+
+
+@pytest.mark.parametrize(
+    "float_seq_class, random_values, float_type",
+    [
(dm.PyFloatSequence, "random_python_vector1d_float", float), + (dm.NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32), + (dm.NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64), + ], +) +def test_vector1d_dm_from_vector(float_seq_class, random_values, float_type, request): + fixture_values = request.getfixturevalue(random_values) + v = dm.Vector1D.from_vector(fixture_values) + assert isinstance(v.data, float_seq_class) + assert isinstance(v.data.values[0], float_type) + _assert_array_check(v, fixture_values, float_type) diff --git a/tests/interfaces/nlp/test_classification.py b/tests/interfaces/nlp/test_classification.py index e5949eec5..8015107fd 100644 --- a/tests/interfaces/nlp/test_classification.py +++ b/tests/interfaces/nlp/test_classification.py @@ -22,6 +22,10 @@ TokenClassificationResult, TokenClassificationResults, ) +from caikit.interfaces.nlp.data_model.classification import ( + InputWarning, + InputWarningReason, +) ## Setup ######################################################################### @@ -60,8 +64,14 @@ generated_token_count=7, finish_reason=FinishReason.STOP_SEQUENCE, seed=42, + warnings=[ + InputWarning( + id=InputWarningReason.UNSUITABLE_INPUT, message="unsuitable input detected" + ) + ], ) + ## Tests ######################################################################## ### ClassificationResult @@ -260,3 +270,5 @@ def _validate_classification_generated_text_result(obj): assert obj.generated_token_count == 7 assert obj.finish_reason == 5 assert obj.seed == 42 + assert obj.warnings[0].id == InputWarningReason.UNSUITABLE_INPUT.value + assert obj.warnings[0].message == "unsuitable input detected" diff --git a/tests/interfaces/nlp/test_reranker.py b/tests/interfaces/nlp/test_reranker.py new file mode 100644 index 000000000..16840ef25 --- /dev/null +++ b/tests/interfaces/nlp/test_reranker.py @@ -0,0 +1,188 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test for reranker +""" + +# Standard +import random +import string + +# Third Party +import pytest + +# Local +from caikit.interfaces.nlp import data_model as dm + +## Setup ######################################################################### + + +@pytest.fixture +def input_document(): + return { + "text": "this is the input text", + "_text": "alternate _text here", + "title": "some title attribute here", + "anything": "another string attribute", + "str_test": "test string", + "int_test": 1234, + "float_test": 9876.4321, + } + + +@pytest.fixture +def input_random_document(): + return { + "text": "".join(random.choices(string.printable, k=100)), + "random_str": "".join(random.choices(string.printable, k=100)), + "random_int": random.randint(-99999, 99999), + "random_float": random.uniform(-99999, 99999), + } + + +@pytest.fixture +def input_documents(input_document, input_random_document): + return [input_document, input_random_document] + + +@pytest.fixture +def input_score(input_document): + return { + "document": input_document, + "index": 1234, + "score": 9876.54321, + "text": "this is the input text", + } + + +@pytest.fixture +def input_random_score(input_random_document): + return { + "document": input_random_document, + "index": random.randint(-99999, 99999), + "score": random.uniform(-99999, 99999), + "text": "".join(random.choices(string.printable, k=100)), + } + + +@pytest.fixture +def input_random_score_3(): + return { + "document": {"text": "random foo3"}, + "index": random.randint(-99999, 99999), + "score": random.uniform(-99999, 99999), + "text": "".join(random.choices(string.printable, k=100)), + } + + +@pytest.fixture +def input_scores(input_score, input_random_score): + return [dm.RerankScore(**input_score), dm.RerankScore(**input_random_score)] + + +@pytest.fixture +def input_scores2(input_random_score, input_random_score_3): + return [ + dm.RerankScore(**input_random_score), + dm.RerankScore(**input_random_score_3), + ] + + +@pytest.fixture +def input_result_1(input_scores): + return {"result": dm.RerankScores(query="foo", scores=input_scores)} + + +@pytest.fixture +def input_result_2(input_scores2): + return {"result": dm.RerankScores(query="bar", scores=input_scores2)} + + +@pytest.fixture +def input_results(input_scores, input_scores2): + return { + "results": [ + dm.RerankScores(query="foo", scores=input_scores), + dm.RerankScores(query="bar", scores=input_scores2), + ] + } + + +@pytest.fixture +def input_sentence_similarity_scores_1(): + return {"scores": [random.uniform(-99999, 99999) for _ in range(10)]} + + +@pytest.fixture +def input_sentence_similarity_result(input_sentence_similarity_scores_1): + return {"result": dm.SentenceSimilarityScores(**input_sentence_similarity_scores_1)} + + +@pytest.fixture +def input_sentence_similarity_scores_2(): + return {"scores": [random.uniform(-99999, 99999) for _ in range(10)]} + + +@pytest.fixture +def input_sentence_similarities_scores( + input_sentence_similarity_scores_1, input_sentence_similarity_scores_2 +): + return [ + dm.SentenceSimilarityScores(**input_sentence_similarity_scores_1), + dm.SentenceSimilarityScores(**input_sentence_similarity_scores_2), + ] + + +@pytest.fixture +def input_sentence_similarity_results(input_sentence_similarities_scores): + return {"results": input_sentence_similarities_scores} + + +## Tests ######################################################################## + + +@pytest.mark.parametrize( + "data_object, inputs", + [ + (dm.RerankScore, "input_score"), + (dm.RerankScore, 
"input_random_score"), + (dm.RerankResult, "input_result_1"), + (dm.RerankResults, "input_results"), + (dm.SentenceSimilarityResult, "input_sentence_similarity_result"), + (dm.SentenceSimilarityResults, "input_sentence_similarity_results"), + ], +) +def test_data_object(data_object, inputs, request): + # Init data object + fixture_values = request.getfixturevalue(inputs) + new_do_from_init = data_object(**fixture_values) + assert isinstance(new_do_from_init, data_object) + assert_fields_match(new_do_from_init, fixture_values) + + # Test to/from proto + proto_from_dm = new_do_from_init.to_proto() + new_do_from_proto = data_object.from_proto(proto_from_dm) + assert isinstance(new_do_from_proto, data_object) + assert_fields_match(new_do_from_proto, fixture_values) + assert new_do_from_init == new_do_from_proto + + # Test to/from json + json_from_dm = new_do_from_init.to_json() + new_do_from_json = data_object.from_json(json_from_dm) + assert isinstance(new_do_from_json, data_object) + assert_fields_match(new_do_from_json, fixture_values) + assert new_do_from_init == new_do_from_json + + +def assert_fields_match(data_object, inputs): + assert all(getattr(data_object, key) == value for key, value in inputs.items()) diff --git a/tests/interfaces/ts/__init__.py b/tests/interfaces/ts/__init__.py new file mode 100644 index 000000000..2068258bf --- /dev/null +++ b/tests/interfaces/ts/__init__.py @@ -0,0 +1,13 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/interfaces/ts/data_model/__init__.py b/tests/interfaces/ts/data_model/__init__.py new file mode 100644 index 000000000..2068258bf --- /dev/null +++ b/tests/interfaces/ts/data_model/__init__.py @@ -0,0 +1,13 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/interfaces/ts/data_model/test_single_timeseries.py b/tests/interfaces/ts/data_model/test_single_timeseries.py new file mode 100644 index 000000000..a460fb4f9 --- /dev/null +++ b/tests/interfaces/ts/data_model/test_single_timeseries.py @@ -0,0 +1,820 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for the Timeseries data model object
+"""
+
+# Standard
+from datetime import timezone
+from typing import Union
+import datetime as dt
+import json
+import warnings
+
+# Third Party
+from pandas import RangeIndex
+import dateutil
+import numpy as np
+import pandas as pd
+import pyspark.sql
+import pytest
+
+# Local
+from caikit.interfaces.ts.data_model._single_timeseries import SingleTimeSeries
+from caikit.interfaces.ts.data_model.backends._spark_backends import (
+    SparkTimeSeriesBackend,
+    ensure_spark_cached,
+)
+from caikit.interfaces.ts.data_model.backends.spark_util import iteritems_workaround
+from caikit.interfaces.ts.data_model.backends.util import (
+    pd_timestamp_to_seconds,
+    strip_periodic,
+)
+from tests.interfaces.ts.data_model.util import (
+    create_extended_test_dfs,
+    df_project,
+    key_helper,
+)
+from tests.interfaces.ts.helpers import test_log
+import caikit.interfaces.ts.data_model as dm
+
+warnings.filterwarnings("ignore", category=ResourceWarning)
+
+## Helpers #####################################################################
+
+# There is a large variety in how the timestamps and columns can be represented
+# in a pandas DataFrame, so we need to handle all of the cases:
+#
+# Timestamp Range Types:
+# - Int numeric
+# - Float numeric
+# - Date range on a regular timedelta (e.g. Day)
+# - Date range on an irregular frequency (e.g. Business Day)
+# - Periodic date on a regular timedelta (e.g. Day)
+# - Periodic date on an irregular frequency (e.g. Business Day)
+#
+# Timestamp Location:
+# - Column
+# - Index
+#
+# Value Columns:
+# - Numeric keys
+# - String keys
+
+# Standard reusable value columns
+value_cols = [range(3), np.arange(0, 1.5, 0.5)]
+value_cols_dict = ({f"val_{i}": val_col for i, val_col in enumerate(value_cols)}, float)
+value_rows = (list(zip(*value_cols)), float)
+value_rows_str = ([(x[0], f"value:{x[1]}") for x in value_rows[0]], "string")
+value_rows_any = ([(x[0], f"value:{x[1]}") for x in value_rows[0]], object)
+value_rows_list = ([(x[0], [x[1], x[1], x[1]]) for x in value_rows[0]], object)
+value_rows_period = (
+    [
+        (x[0], dt.datetime(year=2022, month=1, day=int(round(x[1] + 1))))
+        for x in value_rows[0]
+    ],
+    "datetime64[ns]",
+)
+
+
+def get_ts_sequence(df: pd.DataFrame, ts_source: Union[str, int]) -> pd.Series:
+    """Helper to pull the sequence based on where the source is"""
+    key = key_helper(df, ts_source)
+    return (
+        RangeIndex(start=0, stop=df_project(df).shape[0], step=1)
+        if key is None
+        else df_project(df)[key]
+    )
+
+
+# Timestamp range types
+int_numeric_range = range(0, 30, 10)
+float_numeric_range = np.arange(0, 1.5, 0.5)
+date_range_regular = pd.date_range("2000", freq="D", periods=3)
+date_range_irregular = pd.date_range("2000", freq="B", periods=3)
+period_range_regular = pd.period_range("2000", freq="D", periods=3)
+period_range_irregular = pd.period_range("2000", freq="B", periods=3)
+
+# Reusable timestamp column name
+# NOTE: This is _intentionally_ not the same as the default!
+default_ts_col_name = "ts"
+
+# All testable data frame configurations!
+testable_data_frames = []
+for ts_range in [
+    int_numeric_range,
+    float_numeric_range,
+    date_range_regular,
+    date_range_irregular,
+    period_range_regular,
+    period_range_irregular,
+]:
+    # Data column types
+    for data_arg, data_type in [
+        value_rows,
+        value_cols_dict,
+        value_rows_str,
+        value_rows_any,
+        value_rows_period,
+        value_rows_list,
+    ]:
+        cur_df = pd.DataFrame(data_arg, index=ts_range)
+        # be explicit about type otherwise goes to Any
+        cur_df = cur_df.astype({cur_df.columns[1]: data_type})
+        cur_df.columns = cur_df.columns.astype(str)
+
+        # Add df w/ ts in index
+        testable_data_frames.append(
+            (
+                cur_df,
+                None,
+            )
+        )
+
+        # Add df w/ ts in column
+        if isinstance(data_arg, dict):
+            full_data_arg = dict(**{default_ts_col_name: ts_range}, **data_arg)
+            ts_col_name = default_ts_col_name
+        else:
+            full_data_arg = [[ts] + list(vals) for ts, vals in zip(ts_range, data_arg)]
+            ts_col_name = "0"
+
+        cur_df = pd.DataFrame(full_data_arg)
+        cur_df.columns = cur_df.columns.astype(str)
+        testable_data_frames.append((cur_df, ts_col_name))
+        # spark dataframes get appended to this list below
+
+original_length = len(testable_data_frames)
+test_log.debug("Made a total of %d testable data frames!", original_length)
+
+# replicate and extend the dataframes with pyspark.sql.DataFrame if needed
+testable_data_frames = create_extended_test_dfs(testable_data_frames)
+
+
+def check_df_ts_eq(
+    df: pd.DataFrame,
+    ts: SingleTimeSeries,
+    ts_source: Union[str, int],
+) -> bool:
+    """Helper to make sure the actual data in the data frame and the TimeSeries
+    line up
+    """
+
+    ###################
+    ## Time Sequence ##
+    ###################
+
+    # some evaluations below require a pandas-like api
+    dfeval = df_project(df)
+
+    df_ts_range = get_ts_sequence(dfeval, ts_source)
+    if not ts.time_sequence:
+        test_log.debug("No valid time sequence!")
+        return False
+    if isinstance(df_ts_range.dtype, pd.PeriodDtype):
+        # If it's a periodic index, the timeseries may hold this as either a
+        # PeriodicTimeSequence (if the freq is regular) or a PointTimeSequence
+        # (if the freq is irregular)
+        if ts.time_period:
+            if not ts.time_period.start_time.ts_epoch:
+                test_log.debug("Start time for periodic not based in the epoch")
+                return False
+            if (
+                ts.time_period.start_time.ts_epoch.as_datetime().timestamp()
+                != df_ts_range[0].start_time.timestamp()
+            ):
+                test_log.debug(
"Periodic time sequence start time mismatch: %s != %s", + ts.time_period.start_time.ts_int, + df_ts_range.start, + ) + return False + + # The period may either be a string (pandas period notation) or a + # number of seconds + if ts.time_period.period_length.dt_int is not None: + if ts.time_period.period_length.dt_int != df_ts_range.step: + test_log.debug( + "Period int duration mismatch: %s != %s", + ts.time_period.period_length.dt_int, + df_ts_range.step, + ) + return False + # If not a periodic index, the dm representation is a sequence of points + else: + if not ts.time_points: + test_log.debug("Sequential sequence not represented as points") + return False + + # Make sure the appropriate point types are used + if len(ts.time_points.points) != len(df_ts_range): + test_log.debug( + "Time point length mismatch: %d != %d", + len(ts.time_points.points), + len(df_ts_range), + ) + return False + + # Compare point values. We use view_point.time which will pull the + # appropriate backing point type + for i, (datamodel_point, df_val) in enumerate( + zip(ts.time_points.points, df_ts_range.to_list()) + ): + test_log.debug( + "Comparing TimePoints of type %s / %s", + type(datamodel_point.time), + type(df_val), + ) + datamodel_val = datamodel_point.time + if isinstance(datamodel_val, dm.Seconds): + datamodel_val = datamodel_val.as_datetime() + datamodel_seconds = pd_timestamp_to_seconds(datamodel_val) + df_seconds = pd_timestamp_to_seconds(df_val) + + if datamodel_seconds != df_seconds: + test_log.debug( + "Point value mismatch: %s != %s", datamodel_seconds, df_seconds + ) + return False + + ############ + ## Values ## + ############ + + df_val_cols = [ + val_label if val_label in dfeval.columns else int(val_label) + for val_label in ts.value_labels or dfeval.columns + ] + test_log.debug("df_val_cols: %s", df_val_cols) + if len(df_val_cols) != len(ts.values): + test_log.debug("Value labels and value columns have mismatched length") + return False + + for df_val_col_key, ts_val_seq in zip(df_val_cols, ts.values): + ts_vals = list(ts_val_seq.sequence.values) + df_val_col = dfeval[df_val_col_key] + if len(df_val_col) != len(ts_vals): + test_log.debug("Column %s has length mismatch", df_val_col_key) + return False + + # TODO: what about Any? + # We currently give back the serialized version when values is called, but should it be the deserialized??? 
+ np_value_col = df_val_col.to_numpy() + if ts_val_seq.val_any is not None: + ts_vals = [json.loads(v) for v in ts_vals] + if ts_val_seq.val_timepoint is not None: + ts_vals = [np.datetime64(dateutil.parser.parse(v)) for v in ts_vals] + + # we have to test each separately since each is a vector + if ts_val_seq.val_vector is not None: + ts_vals = [v for v in ts_vals] + if not len(np_value_col) == len(ts_vals): + test_log.debug("vector lengths didn't match") + return False + for i in range(len(ts_vals)): + # we can get ndarrays here as spark stores in ndarrays + ts_to_check = ( + ts_vals[i].tolist() + if isinstance(ts_vals[i], np.ndarray) + else ts_vals[i] + ) + np_to_check = ( + np_value_col[i].tolist() + if isinstance(np_value_col[i], np.ndarray) + else np_value_col[i] + ) + if not ts_to_check == np_to_check: + test_log.debug( + "Column %s has value mismatch: %s != %s", + df_val_col_key, + df_val_col, + ts_vals, + ) + return False + else: + if not (np_value_col == ts_vals).all(): + test_log.debug( + "Column %s has value mismatch: %s != %s", + df_val_col_key, + df_val_col, + ts_vals, + ) + return False + + # ids is more thoroughly tested in MultiTimeSeries, where it is much more useful + ids = ts.ids + if ids is not None: + return False + + return True + + +def compare_np( + df: pd.DataFrame, np_view: np.ndarray, ts_source: Union[str, int] +) -> bool: + """Compare the output numpy view to the input data frame. The following + conventions should be true: + + 1. The first column of the ndarray should be the time sequence + 2. The ndarray's dtype should be the "lowest common denominator" of the time + sequence and value columns (e.g. object < float < int) + """ + + ts_sequence = get_ts_sequence(df, ts_source) + + # Make sure the time sequence matches + # some equality operators are not supported down at the + # java.util.ArrayList level (where pyspark.sql.DataFrames will go + # down to) + # so do it the old fashioned way + for idx, value in enumerate(np_view[:, df.columns.get_loc(ts_source)]): + if value != ts_sequence.iloc[idx]: + test_log.debug( + "Numpy ts sequence mismatch: %s != %s", + ts_sequence[idx], + value, + ) + return False + + # Make sure the value sequences match + df_np = pd.DataFrame(np_view) + val_cols = [col for col in df.columns if col != ts_source] + if ts_source is None: + np_val_cols_len = len(df_np.columns) + else: + np_val_cols_len = len(df_np.columns) - 1 + + np_val_rows = df_np[[df.columns.get_loc(c) for c in val_cols]].to_numpy() + try: + np.testing.assert_equal( + np_val_rows.flatten(), df[val_cols].to_numpy().flatten() + ) + except AssertionError as _: + test_log.debug("NP view data mismatch: %s != %s", np_val_rows, df[val_cols]) + return False + + return True + + +## Tests ####################################################################### + + +def test_not_serializable_value_val_any(): + class Point: + def __init__(self, x, y): + self.x = x + self.y = y + + value_rows_not_serializable = [(x[0], Point(x[1], x[1])) for x in value_rows[0]] + df = pd.DataFrame(value_rows_not_serializable, index=int_numeric_range) + ts = dm.SingleTimeSeries(df) + with pytest.raises(TypeError): + ts.to_json() + + +@pytest.mark.filterwarnings( + "ignore:'PYARROW_IGNORE_TIMEZONE' environment variable was not set.*", + "ignore:`to_list` loads all data into the driver's memory.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", + "ignore:`to_numpy` loads all data into the driver's memory.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", + "ignore:If `index_col` is not specified 
for `to_spark`, the existing index is lost when converting to Spark DataFrame.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", +) +def test_create_single_timeseries_dm(): + df = pd.DataFrame({"time_tick": [0, 1, 2], "value": [1.0, 2.0, 3.0]}) + ts = dm.SingleTimeSeries(df, timestamp_column="time_tick", value_columns=["value"]) + mts = dm.TimeSeries(timeseries=ts) + + assert mts.id_labels == [] + assert mts.timeseries[0].to_proto() == ts.to_proto() + assert mts.timeseries[0].to_json() == ts.to_json() + + reserved_key = dm.TimeSeries._DEFAULT_ID_COL + + spark_mts = mts.as_spark(is_multi=True) + assert isinstance(spark_mts, pyspark.sql.DataFrame) + assert reserved_key in spark_mts.columns + + spark_ts = mts.as_spark() + assert reserved_key not in spark_ts.columns + + pandas_mts = mts.as_pandas(is_multi=True) + assert isinstance(pandas_mts, pd.DataFrame) + assert reserved_key in pandas_mts.columns + + pandas_ts = mts.as_pandas() + assert reserved_key not in pandas_ts.columns + + +def test_no_such_attribute_val_seq(): + value_rows_not_serializable = [(x[0], x[1]) for x in value_rows[0]] + df = pd.DataFrame(value_rows_not_serializable, index=int_numeric_range) + ts = dm.SingleTimeSeries(df) + val_seq = ts.values[0]._backend + with pytest.raises(AttributeError): + val_seq.get_attribute(ts.values[0], "bad") + + +# todo +# Looks like if we have dt_str and seconds, if we are not on a boundary, it gets truncated, that might be what we want +# but can address later +start_times = [ + {"ts_epoch": {"seconds": 946702784.0}}, + {"ts_int": 946702784}, + {"ts_float": 946702784.0}, +] +period_lengths = [ + {"dt_str": "D"}, + {"dt_int": 1}, + {"dt_float": 2.0}, + {"dt_sec": {"seconds": 3}}, +] +periodic_time_seq_input = [] +for start_time in start_times: + for period_length in period_lengths: + periodic_time_seq_input.append((start_time, period_length)) + + +@pytest.mark.parametrize("input", periodic_time_seq_input) +def test_periodic_time_sequence_round_trip(input): + start_time, period_length = input + json_str = json.dumps({"startTime": start_time, "periodLength": period_length}) + periodic_time_sequence = dm.PeriodicTimeSequence.from_json(json_str) + periodic_time_sequence_proto = periodic_time_sequence.to_proto() + periodic_time_sequence = dm.PeriodicTimeSequence.from_proto( + periodic_time_sequence_proto + ) + + # todo we need to handle this issue with camelcase being required for from_json + assert ( + periodic_time_sequence.to_dict()["start_time"] + == json.loads(json_str)["startTime"] + ) + k = next(iter(period_length)) + assert ( + periodic_time_sequence.to_dict()["period_length"][k] + == json.loads(json_str)["periodLength"][k] + ) + + +period_lengths = ["D", 1, 2.0, dm.Seconds.from_json(json.dumps({"seconds": 3}))] +results = ["D", 1, 2.0, {"seconds": 3}] + + +@pytest.mark.parametrize("input", period_lengths) +def test_time_duration_time_attribute(input): + time_duration = dm.TimeDuration(time=input) + assert time_duration.time == input + + +@pytest.mark.skip( + "Raising an error for invalid oneof field values hasn't been implemented" +) +def test_time_duration_bad_attribute(): + with pytest.raises(AttributeError): + _ = dm.TimeDuration(time=True) + + +time_points = [ + 946702784, + 946702784.0, + dm.Seconds.from_json(json.dumps({"seconds": 3})), +] + + +@pytest.mark.parametrize("input", time_points) +def test_time_point_time_attribute(input): + time_point = dm.TimePoint(time=input) + assert time_point.time == input + + +@pytest.mark.skip( + "Raising an error for invalid oneof field values 
hasn't been implemented" +) +def test_time_point_time_bad_attribute(): + with pytest.raises(AttributeError): + _ = dm.TimePoint(time=True) + + +@pytest.mark.skip( + "Raising an error for invalid oneof field values hasn't been implemented" +) +def test_time_duration_never_set(): + with pytest.raises(AttributeError): + _ = dm.TimeDuration() + + +def test_seconds(): + # setattr test + seconds = dm.Seconds(seconds=1) + assert dt.timedelta(seconds=1) == seconds.as_timedelta() + + # from timedelta + seconds = dm.Seconds.from_timedelta(dt.timedelta(seconds=2)) + assert dt.timedelta(seconds=2) == seconds.as_timedelta() + + # from timestamp + # Third Party + import pytz + + seconds = dm.Seconds.from_datetime(dt.datetime(1990, 1, 1, tzinfo=timezone.utc)) + assert seconds.to_dict() == {"seconds": 631152000} + + +def test_empty_val_sequence(): + seq = dm.ValueSequence() + assert seq.sequence is None + + +def get_df_len(df_in): + if isinstance(df_in, pd.DataFrame): + return len(df_in) + else: + return len(df_in.toPandas()) + + +def get_col_list(df_in, col): + if isinstance(df_in, pd.DataFrame): + return df_in[col].values.tolist() + else: + return df_in.toPandas()[col].values.tolist() + + +@pytest.mark.filterwarnings( + "ignore:'PYARROW_IGNORE_TIMEZONE' environment variable was not set.*", + "ignore:`to_list` loads all data into the driver's memory.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", + "ignore:`to_numpy` loads all data into the driver's memory.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", +) +@pytest.mark.parametrize("df_ts_data", testable_data_frames) +def test_timeseries_pd(df_ts_data): + """Tests for TimeSeries objects backed by Pandas data frames. This test is + parametrized over ALL different flavors of data frame since all layouts + should behave the same! 
+ """ + df, ts_source = df_ts_data + + with ensure_spark_cached(df) as df: + # this doesn't for spark dataframes + test_log.debug("Running test_timeseries_pd:\n%s", df) + test_log.debug("ts_source: %s", ts_source) + + ts = dm.SingleTimeSeries(df, timestamp_column=ts_source) + + # Verify that the pandas view round-trips (and doesn't make a copy) + # if we're using a spark backend, this is not a valid expectation + # we need to check if ts_source isn't none as if it is none we will get a new dataframe that + if not isinstance(ts._backend, SparkTimeSeriesBackend): + assert ts.as_pandas() is df + else: + assert ts.as_pandas().equals(df.toPandas()) + + # make sure include_timestamps is working properly + if ts_source is None: + pdf = ts.as_pandas() + assert (pdf.columns == df.columns).all() + + pdf = ts.as_pandas(include_timestamps=False) + assert (pdf.columns == df.columns).all() + + pdf = ts.as_pandas(include_timestamps=True) + assert ( + pdf["timestamp"].values == np.arange(start=0, stop=get_df_len(df)) + ).all() + else: + pdf = ts.as_pandas() + assert (pdf.columns == df.columns).all() + + pdf = ts.as_pandas(include_timestamps=False) + assert pdf.columns.tolist() == [x for x in df.columns if x != ts_source] + + pdf = ts.as_pandas(include_timestamps=True) + assert get_col_list(pdf, ts_source) == get_col_list(df, ts_source) + + # Verify that json serialization round-trips + json_repr = ts.to_json() + + json_round_trip = dm.SingleTimeSeries.from_json(json_repr) + assert check_df_ts_eq(df, json_round_trip, ts_source) + + json_obj = json.loads(json_repr) + # Quick test to make sure that we can ingest json with start_time and period_length not being a pd.Series + if json_obj.get("time_period"): + json_obj["time_period"] = { + "start_time": {"ts_int": 5}, + "period_length": { + "dt_int": 10, + }, + } + + ts_new_period = dm.SingleTimeSeries.from_json(json_obj) + assert ts_new_period.time_period.to_dict() == json_obj["time_period"] + + # static as it never changes here + to_check = [5, 15, 25] + + # this is not a possible case, but checking for completeness + if ( + ts_new_period.timestamp_label is None + or ts_new_period.timestamp_label == "" + ): + assert ts_new_period.as_pandas().index.values.tolist() == to_check + else: + assert ( + ts_new_period.as_pandas()[ + ts_new_period.timestamp_label + ].values.tolist() + == to_check + ) + + # Verify that the pandas view looks the same if not from backend + # assert check_df_ts_eq(ts_new_period.as_pandas(), ts_new_period, ts_source) + + json_obj["time_period"] = { + "start_time": {"ts_epoch": {"seconds": 631195200}}, + "period_length": { + "dt_float": 3600.0, + }, + } + + ts_new_period = dm.SingleTimeSeries.from_json(json_obj) + assert ts_new_period.time_period.to_dict() == json_obj["time_period"] + + # static as it never changes here + to_check = [ + pd.Period(value=dt.datetime.utcfromtimestamp(631195200), freq="H"), + pd.Period( + value=dt.datetime.utcfromtimestamp(631195200 + 3600), freq="H" + ), + pd.Period( + value=dt.datetime.utcfromtimestamp(631195200 + 3600 * 2), freq="H" + ), + ] + + # this is not a possible case, but checking for completeness + if ( + ts_new_period.timestamp_label is None + or ts_new_period.timestamp_label == "" + ): + assert ts_new_period.as_pandas().index.values.tolist() == to_check + else: + assert ( + ts_new_period.as_pandas()[ + ts_new_period.timestamp_label + ].values.tolist() + == to_check + ) + + # Verify that the pandas view looks the same if not from backend + # assert check_df_ts_eq(ts_new_period.as_pandas(), 
ts_new_period, ts_source) + + # Verify that proto serialization round-trips + proto_repr = ts.to_proto() + proto_round_trip = dm.SingleTimeSeries.from_proto(proto_repr) + assert check_df_ts_eq(df, proto_round_trip, ts_source) + + # Verify that the pandas view looks the same if not from backend + assert check_df_ts_eq(proto_round_trip.as_pandas(), ts, ts_source) + + +@pytest.mark.filterwarnings( + "ignore:If `index_col` is not specified for `to_spark`, the existing index is lost when converting to Spark DataFrame.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning" +) +@pytest.mark.parametrize("df_ts_data", testable_data_frames) +def test_timeseries_spark(df_ts_data): + """Tests for the Spark view of TimeSeries objects. This test is + parametrized over ALL different flavors of data frame since all layouts + should behave the same! + """ + df, ts_source = df_ts_data + + with ensure_spark_cached(df) as df: + # this doesn't work for spark dataframes + test_log.debug("Running test_timeseries_spark:\n%s", df) + test_log.debug("ts_source: %s", ts_source) + + ts = dm.SingleTimeSeries(df, timestamp_column=ts_source) + + # Verify that as_spark returns the same data we passed in + from_ts = ts.as_spark().toPandas().copy(deep=True) + from_ts.reset_index(drop=True, inplace=True) + from_df = ( + df.toPandas() + if isinstance(ts._backend, SparkTimeSeriesBackend) + else df.copy(deep=True) + ) + from_df.reset_index(drop=True, inplace=True) + from_df = strip_periodic(from_df) + from_df_numpy = from_df.to_numpy() + for idx, from_ts_val in enumerate(from_ts.to_numpy()): + val_ts = ( + from_ts_val.tolist() if hasattr(from_ts_val, "tolist") else from_ts_val + ) + val_df = ( + from_df_numpy[idx].tolist() + if hasattr(from_df_numpy[idx], "tolist") + else from_df_numpy[idx] + ) + np.testing.assert_equal(val_ts, val_df) + # assert val_ts[0] == val_df[0], idx + # assert (val_ts[1] == val_df[1]).all(), idx + + # print(ts.as_spark().toPandas()) + # print(df_project(df)) + + # make sure include_timestamps is working properly + dftocompare = ( + df.toPandas() + if isinstance(df, pyspark.sql.DataFrame) + else strip_periodic(df, create_copy=True) + ) + if ts_source is None: + pdf = ts.as_spark().toPandas() + assert (pdf.columns == dftocompare.columns).all() + + pdf = ts.as_spark(include_timestamps=False).toPandas() + assert (pdf.columns == dftocompare.columns).all() + + pdf = ts.as_spark(include_timestamps=True).toPandas() + assert ( + pdf["timestamp"].values + == np.arange(start=0, stop=get_df_len(dftocompare)) + ).all() + else: + pdf = ts.as_spark().toPandas() + assert (pdf.columns == dftocompare.columns).all() + + pdf = ts.as_spark(include_timestamps=False).toPandas() + assert pdf.columns.tolist() == [ + x for x in dftocompare.columns if x != ts_source + ] + + pdf = ts.as_spark(include_timestamps=True).toPandas() + assert get_col_list(pdf, ts_source) == get_col_list(dftocompare, ts_source) + + +def test_timeseries_raises_on_bad_input(): + # Local + import caikit + + with pytest.raises(NotImplementedError): + ts = dm.SingleTimeSeries([]) + + class Dummy: + def to_list(self): + return [] + + assert [] == iteritems_workaround(Dummy(), force_list=False) + + caikit.interfaces.ts.data_model._single_timeseries.HAVE_PYSPARK = False + + df = pd.DataFrame([1, 2, 3]) + ts = dm.SingleTimeSeries(df) + with pytest.raises(NotImplementedError): + ts.as_spark() + + caikit.interfaces.ts.data_model._single_timeseries.HAVE_PYSPARK = True diff --git a/tests/interfaces/ts/data_model/test_timeseries.py
b/tests/interfaces/ts/data_model/test_timeseries.py new file mode 100644 index 000000000..b66890653 --- /dev/null +++ b/tests/interfaces/ts/data_model/test_timeseries.py @@ -0,0 +1,832 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the Timeseries data model object +""" + +# Standard +from datetime import datetime, timezone +from typing import Iterable, Union +import datetime as dt +import json +import os +import traceback +import warnings + +# Third Party +from pandas import RangeIndex +import dateutil +import numpy as np +import pandas as pd +import pyspark +import pytest + +# Local +from caikit.core.data_model import ProducerId +from caikit.interfaces.ts.data_model import SingleTimeSeries +from caikit.interfaces.ts.data_model.backends._spark_backends import ensure_spark_cached +from caikit.interfaces.ts.data_model.backends.spark_util import iteritems_workaround +from caikit.interfaces.ts.data_model.backends.util import ( + pd_timestamp_to_seconds, + strip_periodic, +) +from tests.interfaces.ts.data_model.util import create_extended_test_dfs, df_project +from tests.interfaces.ts.helpers import sslocal_fixture, test_log +import caikit.interfaces.ts.data_model as dm + +warnings.filterwarnings("ignore", category=ResourceWarning) + +test_log.setLevel("DEBUG") + + +keys = [["a", "a", "b"], ["c", "d", "e"]] + +key_cols = {f"key_{i}": keys[i] for i in range(2)} + +# Standard reusable value columns +value_cols = [range(3), np.arange(0, 1.5, 0.5)] +value_cols_dict = ({f"val_{i}": val_col for i, val_col in enumerate(value_cols)}, float) +value_rows = (list(zip(*value_cols)), float) +value_rows_str = ([(x[0], f"value:{x[1]}") for x in value_rows[0]], "string") +value_rows_any = ([(x[0], f"value:{x[1]}") for x in value_rows[0]], object) +value_rows_list = ([(x[0], [x[1], x[1], x[1]]) for x in value_rows[0]], object) +value_rows_period = ( + [ + ( + x[0], + dt.datetime(year=2022, month=1, day=int(round(x[1] + 1))), + ) + for x in value_rows[0] + ], + "datetime64[ns]", +) + + +# Timestamp range types +int_numeric_range = range(0, 30, 10) +float_numeric_range = np.arange(0, 1.5, 0.5) +date_range_regular = pd.date_range("2000", freq="D", periods=3) +date_range_irregular = pd.date_range("2000", freq="B", periods=3) +period_range_regular = pd.period_range("2000", freq="D", periods=3) +period_range_irregular = pd.period_range("2000", freq="B", periods=3) + +# Reusable timestamp column name +# NOTE: This is _intentionally_ not the same as the default! +default_ts_col_name = "ts" + +# All testable data frame configurations! 
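+# Each entry appended below is a tuple of the form
+#   (df, ts_col_name, key_col_names, value_col_name)
+# where ts_col_name is None when the time sequence lives in the index and
+# value_col_name is None when every non-key column is treated as a value.
+# A hypothetical entry (for illustration only, not one of the generated
+# frames) would look like:
+#   (pd.DataFrame({"ts": [0, 1], "key_0": ["a", "a"], "val_0": [1.0, 2.0]}),
+#    "ts", ["key_0"], "val_0")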
+testable_data_frames = [] +for ts_range in [ + int_numeric_range, + float_numeric_range, + date_range_regular, + date_range_irregular, + period_range_regular, + period_range_irregular, +]: + # Data column types + for data_arg, data_type in [ + value_rows, + value_cols_dict, + value_rows_str, + value_rows_any, + value_rows_period, + value_rows_list, + ]: + cur_df = pd.DataFrame(data_arg) + + for k, v in key_cols.items(): + cur_df[k] = v + # be explicit about the type, otherwise it goes to Any + cur_df = cur_df.astype({cur_df.columns[1]: data_type}) + cur_df.columns = cur_df.columns.astype(str) + + testable_data_frames.append((cur_df, None, list(key_cols.keys())[0], None)) + for i in range(len(key_cols)): + k = list(key_cols.keys())[: i + 1] + # Add df w/ ts in index + testable_data_frames.append((cur_df, None, k, None)) + + # Add df w/ ts in column + if isinstance(data_arg, dict): + full_data_arg = dict(**{default_ts_col_name: ts_range}, **data_arg) + ts_col_name = default_ts_col_name + else: + full_data_arg = [[ts] + list(vals) for ts, vals in zip(ts_range, data_arg)] + ts_col_name = "0" + cur_df = pd.DataFrame(full_data_arg) + + for k, v in key_cols.items(): + cur_df[k] = v + cur_df.columns = cur_df.columns.astype(str) + testable_data_frames.append( + (cur_df, ts_col_name, list(key_cols.keys())[0], None) + ) + + # value column is specified + if isinstance(data_arg, dict): + value_keys = list(data_arg.keys()) + value_keys.append(None) + for value_key in value_keys: + for i in range(len(key_cols)): + k = list(key_cols.keys())[: i + 1] + # Add df w/ ts in column + testable_data_frames.append((cur_df, ts_col_name, k, value_key)) + # value column is unspecified + else: + for i in range(len(key_cols)): + k = list(key_cols.keys())[: i + 1] + # Add df w/ ts in column + testable_data_frames.append((cur_df, ts_col_name, k, None)) + + # just a simple test to include int keys + cur_df["key_int_1"] = np.array([1, 1, 3], dtype=np.int32) + cur_df["key_int_2"] = np.array([4, 5, 6], dtype=np.int32) + testable_data_frames.append( + (cur_df, ts_col_name, ["key_int_1", "key_int_2"], None) + ) + +# replicate and extend the dataframes with pyspark.sql.DataFrame equivalents if needed +testable_data_frames = create_extended_test_dfs(testable_data_frames) + + +def get_df_len(df_in): + if isinstance(df_in, pd.DataFrame): + return len(df_in) + else: + return len(df_in.toPandas()) + + +def get_col_list(df_in, col): + if isinstance(df_in, pd.DataFrame): + return df_in[col].values.tolist() + else: + return df_in.toPandas()[col].values.tolist() + + +@pytest.mark.filterwarnings( + "ignore:'PYARROW_IGNORE_TIMEZONE' environment variable was not set.*", + "ignore:`to_list` loads all data into the driver's memory.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", + "ignore:`to_numpy` loads all data into the driver's memory.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", + "ignore:In a future version of pandas, a length 1 tuple will be returned.*:", +) +@pytest.mark.parametrize("df_mts_data", testable_data_frames) +def test_timeseries_spark(df_mts_data): + """Subset of test_timeseries_pd to exercise to_spark functionality of multi_timeseries""" + df, ts_source, key_source, value_source = df_mts_data + + with ensure_spark_cached(df) as df: + value_source = None if value_source is None else [value_source] + + mts = dm.TimeSeries( + df, + key_column=key_source, + timestamp_column=ts_source, + value_columns=value_source, + ) + + # make sure include_timestamps is working properly + if ts_source is None: + pdf = mts.as_spark().toPandas()
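+ # (as_spark() returns a pyspark.sql.DataFrame; toPandas() collects it
+ # to the driver so the pandas-level assertions below can be reused)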
+ assert (pdf.columns == df_project(df).columns).all() + + pdf = mts.as_spark(include_timestamps=False).toPandas() + assert (pdf.columns == df_project(df).columns).all() + + pdf = mts.as_spark(include_timestamps=True).toPandas() + for _, group in pdf.groupby(key_source): + assert ( + group["timestamp"].values + == np.arange(start=0, stop=get_df_len(group)) + ).all() + else: + pdf = mts.as_spark().toPandas() + assert pdf.columns.tolist() == df_project(df).columns.tolist() + pdf = mts.as_spark(include_timestamps=False).toPandas() + assert pdf.columns.tolist() == [x for x in df.columns if x != ts_source] + pdf = mts.as_spark(include_timestamps=True).toPandas() + assert get_col_list(pdf, ts_source) == get_col_list( + strip_periodic(df, create_copy=True), ts_source + ) + + +@pytest.mark.filterwarnings( + "ignore:'PYARROW_IGNORE_TIMEZONE' environment variable was not set.*", + "ignore:`to_list` loads all data into the driver's memory.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", + "ignore:`to_numpy` loads all data into the driver's memory.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", + "ignore:If `index_col` is not specified for `to_spark`, the existing index is lost when converting to Spark DataFrame.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", + "ignore:In a future version of pandas, a length 1 tuple will be returned.*:", +) +@pytest.mark.slow +@pytest.mark.parametrize("df_mts_data", testable_data_frames) +def test_timeseries_pd(df_mts_data): + """Tests for TimeSeries objects backed by Pandas data frames. This test is + parametrized over ALL different flavors of data frame since all layouts + should behave the same! + """ + df, ts_source, key_source, value_source = df_mts_data + + with ensure_spark_cached(df) as df: + test_log.debug("Running test_timeseries_pd:\n%s", df) + test_log.debug("ts_source: %s", ts_source) + test_log.debug("key_source: %s", key_source) + + value_source = None if value_source is None else [value_source] + + mts = dm.TimeSeries( + df, + key_column=key_source, + timestamp_column=ts_source, + value_columns=value_source, + ) + + if isinstance(df, pd.DataFrame): + assert mts.as_pandas() is df + elif isinstance(df, pyspark.sql.DataFrame): + # best we can do is this + assert mts.as_pandas().equals(df.toPandas()) + else: + ... + + # make sure include_timestamps is working properly + if ts_source is None: + pdf = mts.as_pandas() + assert (pdf.columns == df.columns).all() + + pdf = mts.as_pandas(include_timestamps=False) + assert (pdf.columns == df.columns).all() + + pdf = mts.as_pandas(include_timestamps=True) + for _, group in pdf.groupby(key_source): + assert ( + group["timestamp"].values + == np.arange(start=0, stop=get_df_len(group)) + ).all() + else: + pdf = mts.as_pandas() + assert (pdf.columns == df.columns).all() + + pdf = mts.as_pandas(include_timestamps=False) + assert pdf.columns.tolist() == [x for x in df.columns if x != ts_source] + + pdf = mts.as_pandas(include_timestamps=True) + assert get_col_list(pdf, ts_source) == get_col_list(df, ts_source) + + # note: Added this check in to speed up CI build as we still get 100% coverage without it. 
+ # having said that, this should be added to a cron-job as part of a nightly build + if ( + isinstance(df, pd.DataFrame) + or os.getenv("RUN_SPARKDF_SLOW_TESTS", "0") == "1" + ): + # Verify that json serialization round-trips + json_repr = mts.to_json() + json_round_trip = dm.TimeSeries.from_json(json_repr) + try: + assert check_df_mts_eq(df, json_round_trip, ts_source, key_source) + + # Verify that proto serialization round-trips + proto_repr = mts.to_proto() + proto_round_trip = dm.TimeSeries.from_proto(proto_repr) + assert check_df_mts_eq(df, proto_round_trip, ts_source, key_source) + + # Verify that the original source can convert properly + assert check_df_mts_eq(df, mts, ts_source, key_source) + except Exception: + traceback.print_exc() + assert False + + +def get_ts_sequence( + df: pd.DataFrame, + ts_source: Union[str, int], + key_source: Union[str, Iterable[str], None], + ids: Union[Iterable[str], None], +) -> pd.Series: + """Helper to pull the sequence based on where the source is""" + df_new = df_project(df).copy() + if key_source is not None: + if isinstance(key_source, str): + key_source = [key_source] + for i in range(len(key_source)): + df_new = df_new[df_new[key_source[i]] == ids[i]] + return ( + RangeIndex(start=0, stop=df_project(df).shape[0], step=1) + if ts_source is None + else iteritems_workaround(df_new[ts_source]) + ) + + +def check_df_mts_eq( + df: pd.DataFrame, + mts: dm.TimeSeries, + ts_source: Union[str, int], + key_source: Union[str, Iterable[str]], +) -> bool: + # test internal data + for ts in mts.timeseries: + res = check_df_ts_eq(df, ts, ts_source, key_source) + if not res: + return False + + return True + + +def check_df_ts_eq( + df: pd.DataFrame, + datamodel_ts: SingleTimeSeries, + ts_source: Union[str, int], + key_source, +) -> bool: + """Helper to make sure the actual data in the data frame and the TimeSeries + line up + """ + + ################### + ## Time Sequence ## + ################### + + ts_from_df = get_ts_sequence(df, ts_source, key_source, datamodel_ts.ids.values) + if not datamodel_ts.time_sequence: + test_log.debug("No valid time sequence!") + return False + if isinstance(ts_from_df.dtype, pd.PeriodDtype): + # If it's a periodic index, the timeseries may hold this as either a + # PeriodicTimeSequence (if the freq is regular) or a PointTimeSequence + # (if the freq is irregular) + if datamodel_ts.time_period: + if not datamodel_ts.time_period.start_time.ts_epoch: + test_log.debug("Start time for periodic not based in the epoch") + return False + if ( + datamodel_ts.time_period.start_time.ts_epoch.as_datetime().timestamp() + != ts_from_df.iloc[0].start_time.timestamp() + ): + test_log.debug( + "Periodic time sequence start time mismatch: %s != %s", + datamodel_ts.time_period.start_time.ts_epoch.as_datetime(), + ts_from_df.iloc[0].start_time, + ) + return False + + # The period may either be a string (pandas period notation) or a + # number of seconds + if datamodel_ts.time_period.period_length.dt_str: + if ( + datamodel_ts.time_period.period_length.dt_str + != ts_from_df.dtype.freq.name + ): + test_log.debug( + "Period str duration mismatch: %s != %s", + datamodel_ts.time_period.period_length.dt_str, + ts_from_df.dtype.freq.name, + ) + return False + + elif not datamodel_ts.time_period.period_length.dt_sec: + test_log.debug("Period length for periodic not in seconds or str") + return False + elif ( + datamodel_ts.time_period.period_length.dt_sec.as_timedelta() + != ts_from_df.dtype.freq.delta + ): + test_log.debug( + "Period length mismatch: %s != %s",
datamodel_ts.time_period.period_length.dt_sec.as_timedelta(), + ts_from_df.dtype.freq.delta, + ) + return False + elif isinstance(ts_from_df, RangeIndex): + if datamodel_ts.time_period.start_time.ts_int is None: + test_log.debug("Start time for periodic not based in the int") + return False + if datamodel_ts.time_period.start_time.ts_int != ts_from_df.start: + test_log.debug( + "Periodic time sequence start time mismatch: %s != %s", + datamodel_ts.time_period.start_time.ts_int, + ts_from_df.start, + ) + return False + + # The period may either be a string (pandas period notation) or a + # number of seconds + if datamodel_ts.time_period.period_length.dt_int is not None: + if datamodel_ts.time_period.period_length.dt_int != ts_from_df.step: + test_log.debug( + "Period int duration mismatch: %s != %s", + datamodel_ts.time_period.period_length.dt_int, + ts_from_df.step, + ) + return False + # If not a periodic index, the dm representation is a sequence of points + else: + if not datamodel_ts.time_points: + test_log.debug("Sequential sequence not represented as points") + return False + + # Make sure the appropriate point types are used + if len(datamodel_ts.time_points.points) != len(ts_from_df): + test_log.debug( + "Time point length mismatch: %d != %d", + len(datamodel_ts.time_points.points), + len(ts_from_df), + ) + return False + + # Compare point values. We use view_point.time which will pull the + # appropriate backing point type + for i, (datamodel_point, df_point) in enumerate( + zip(datamodel_ts.time_points.points, ts_from_df) + ): + test_log.debug( + "Comparing TimePoints of type %s / %s", + type(datamodel_point.time), + type(df_point), + ) + datamodel_time = datamodel_point.time + if isinstance(datamodel_time, dm.Seconds): + datamodel_time = datamodel_time.as_datetime() + # direct comparison of datetime and np.datetime64 objects + # are fraught with ambiguity. Consider this example: + # dt = datetime(year=2000, month=1, day=1, second=0, microsecond=0) + # npdt = np.datetime64(dt.isoformat()) + # npdt == dt # True + # np.datetime64("2000-01-01T00:00:00.000000000") == dt # False ! + # np.datetime64("2000-01-01T00:00:00.000000000") == npdt # True ! + # Confusing to say the least + datamodel_seconds = pd_timestamp_to_seconds(datamodel_time) + df_seconds = pd_timestamp_to_seconds(df_point) + if datamodel_seconds != df_seconds: + test_log.debug( + "Point value mismatch: %s != %s delta is %s", + datamodel_seconds, + df_seconds, + datamodel_seconds - df_seconds, + ) + return False + + ############ + ## Values ## + ############ + + df_val_cols = [ + val_label if val_label in df.columns else int(val_label) + for val_label in datamodel_ts.value_labels or df.columns + ] + test_log.debug("df_val_cols: %s", df_val_cols) + if len(df_val_cols) != len(datamodel_ts.values): + test_log.debug("Value labels and value columns have mismatched length") + return False + + for df_val_col_key, ts_val_seq in zip(df_val_cols, datamodel_ts.values): + ts_vals = list(ts_val_seq.sequence.values) + ids = datamodel_ts.ids.values + df_new = df_project(df).copy() + if isinstance(key_source, str): + key_source = [key_source] + for i in range(len(key_source)): + df_new = df_new[df_new[key_source[i]] == ids[i]] + df_val_col = df_new[df_val_col_key] + if len(df_val_col) != len(ts_vals): + test_log.debug("Column %s has length mismatch", df_val_col_key) + return False + + # TODO: what about Any? + # We currently give back the serialized version when values is called, but should it be the deserialized??? 
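+ # val_any entries come back as serialized JSON strings and
+ # val_timepoint entries as timestamp strings, so both are deserialized
+ # below before being compared against the raw dataframe values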
+ np_value_col = df_val_col.to_numpy() + if ts_val_seq.val_any is not None: + ts_vals = [json.loads(v) for v in ts_vals] + if ts_val_seq.val_timepoint is not None: + ts_vals = [np.datetime64(dateutil.parser.parse(v)) for v in ts_vals] + + # we have to test each separately since each is a vector + if ts_val_seq.val_vector is not None: + ts_vals = [v for v in ts_vals] + if not len(np_value_col) == len(ts_vals): + test_log.debug("vector lengths didn't match") + return False + for idx, ts_vals_val in enumerate(ts_vals): + to_check = ( + np_value_col[idx].tolist() + if isinstance(np_value_col[idx], np.ndarray) + else np_value_col[idx] + ) + if not to_check == ( + ts_vals_val.tolist() + if hasattr(ts_vals_val, "tolist") + else ts_vals_val + ): + test_log.debug( + "Column %s has value mismatch: %s != %s", + df_val_col_key, + df_val_col, + ts_vals_val, + ) + return False + else: + if not (np_value_col == ts_vals).all(): + test_log.debug( + "Column %s has value mismatch: %s != %s", + df_val_col_key, + df_val_col, + ts_vals, + ) + return False + + return True + + +def _cmp(it1, it2): + # oh the joys of having a language with no types + try: + np.testing.assert_equal( + [pd_timestamp_to_seconds(x) for x in it1], + [pd_timestamp_to_seconds(x) for x in it2], + ) + return True + except AttributeError: + ... + except ValueError: + ... + except AssertionError: + return False + + if hasattr(it1, "to_numpy") and hasattr(it2, "to_numpy"): + return (it1.to_numpy() == it2.to_numpy()).all() + else: + return (it1 == it2).all() + + +def compare_np( + df: pd.DataFrame, + np_view: np.ndarray, + ts_source: Union[str, int], + key_source: Union[str, Iterable[str]], + value_source: Iterable[str], + data_model_columns, # columns in the data model. This is required as if we are coming from json or some other source, we will not have all of the columns pertaining to the original dataframe +) -> bool: + """Compare the output numpy view to the input data frame. The following + conventions should be true: + + 1. The first column of the ndarray should be the time sequence + 2. The ndarray's dtype should be the "lowest common denominator" of the time + sequence and value columns (e.g. 
object < float < int) + """ + ts_range = get_ts_sequence(df, ts_source, None, None) + + # Make sure the time sequence on timestamp matches + if isinstance(key_source, str): + key_source = [key_source] + + if isinstance(value_source, str): + value_source = [value_source] + + # ordering when dealing with partitioned spark dataframes is + # not guaranteed so we'll need to do this: + + np_view_as_pandas = pd.DataFrame( + columns=data_model_columns, data=np_view + ).sort_values(by=key_source + [ts_source]) + if not _cmp(ts_range, np_view_as_pandas[ts_source]): + test_log.debug( + "Numpy ts sequence mismatch: %s != %s", + ts_range, + np_view_as_pandas[ts_source], + ) + return False + + val_cols = [ + col + for col in data_model_columns + if col != ts_source and (col in key_source or col in value_source) + ] + np_val_rows = np_view_as_pandas[val_cols].to_numpy() + + if not _cmp( + np_view_as_pandas[val_cols].to_numpy(), df_project(df)[val_cols].to_numpy() + ): + test_log.debug( + "NP view data mismatch: %s != %s", np_val_rows, df_project(df)[value_source] + ) + return False + + return True + + +def test_pd_timestamp_to_seconds(): + adate = datetime(2000, 1, 1, tzinfo=timezone.utc) + assert pd_timestamp_to_seconds(pd.Period(adate.isoformat())) == adate.timestamp() + assert ( + pd_timestamp_to_seconds(np.datetime64(datetime(2000, 1, 1, 0, 0, 0, 0))) + == adate.timestamp() + ) + assert pd_timestamp_to_seconds(adate) == adate.timestamp() + assert pd_timestamp_to_seconds(adate.timestamp()) == adate.timestamp() + assert pd_timestamp_to_seconds(pd.Timestamp(adate.isoformat())) == adate.timestamp() + with pytest.raises(Exception): + pd_timestamp_to_seconds([]) + + +@pytest.fixture(scope="module") +def trivial_pandas_df(): + return pd.DataFrame(columns=["a", "b", "c"], data=[[1, 2, 3], [1, 4, 5]]) + + +@pytest.fixture(scope="module") +def trivial_spark_df(trivial_pandas_df, sslocal_fixture): + return sslocal_fixture.createDataFrame(trivial_pandas_df) + + +def test_multi_timeseries_raises_on_bad_input(trivial_pandas_df): + # Local + import caikit + + caikit.interfaces.ts.data_model.timeseries.HAVE_PYSPARK = False + df = trivial_pandas_df + ts = dm.TimeSeries(df, key_column="a") + with pytest.raises(NotImplementedError): + ts.as_spark() + caikit.interfaces.ts.data_model.timeseries.HAVE_PYSPARK = True + + with pytest.raises(NotImplementedError): + iteritems_workaround("foobar") + + +# this method could be called internally, we just want to guard for that +def test_multi_timeseries_bad_attribute(trivial_pandas_df): + df = trivial_pandas_df + ts = dm.TimeSeries(df, key_column="a") + + with pytest.raises(ValueError): + ts._backend.get_attribute(ts, "bad_attribute") + + +def test_multi_timeseries_spark_bad_attribute(trivial_spark_df): + df = trivial_spark_df + ts = dm.TimeSeries( + df, + key_column="a", + ) + + with pytest.raises(ValueError): + ts._backend.get_attribute(ts, "bad_attribute") + + +def test_as_spark_with_str_key_cols(trivial_spark_df): + df = trivial_spark_df + ts = dm.TimeSeries( + df, + key_column="a", + ) + p1 = ts.as_spark(include_timestamps=True).toPandas() + p2 = ts.as_pandas(include_timestamps=True) + assert (p1.to_numpy() == p2.to_numpy()).all() + + +def test_as_spark_with_producer_id(trivial_spark_df): + df = trivial_spark_df + ts = dm.TimeSeries( + df, + key_column="a", + producer_id=ProducerId("Test", "1.2.3"), + ) + + assert ts.producer_id.name == "Test" + assert ts.producer_id.version == "1.2.3" + + +def test_mts_len(sslocal_fixture): + df = pd.concat( + [ + pd.DataFrame( + [(x, "A", x
* 5, x * 1.333) for x in range(10)], + columns=["ts", "key", "val", "val2"], + ), + pd.DataFrame( + [(x, "B", x * 5, x * 1.333) for x in range(30)], + columns=["ts", "key", "val", "val2"], + ), + ] + ) + + # spark + mts = dm.TimeSeries( + sslocal_fixture.createDataFrame(df), + timestamp_column="ts", + key_column="key", + ) + + assert len(mts) == 40 + + # pandas + mts = dm.TimeSeries( + df, + timestamp_column="ts", + key_column="key", + ) + + assert len(mts) == 40 + + # no backend + mts = dm.TimeSeries( + df, + timestamp_column="ts", + ) + + assert len(mts) == 40 + + +@pytest.mark.filterwarnings( + "ignore:.*loads all data into the driver's memory.*:pyspark.pandas.utils.PandasAPIOnSparkAdviceWarning", + "ignore:toPandas attempted Arrow optimization.*:UserWarning", +) +def test_dm_serializes_spark_vectors(sslocal_fixture): + # Standard + from datetime import datetime + + # Third Party + from pyspark.ml.linalg import Vectors, VectorUDT + from pyspark.sql.types import ( + ArrayType, + DateType, + StringType, + StructField, + StructType, + ) + + v = Vectors.dense([1.0, 2.0]) + schema = StructType( + [ + StructField("date", DateType(), True), + StructField("id", StringType(), True), + StructField("value", VectorUDT(), True), + ] + ) + + data = [ + (datetime(year=2020, month=1, day=1), "id", v), + ] + df = sslocal_fixture.createDataFrame(data=data, schema=schema) + + mts = dm.TimeSeries( + df, + key_column="id", + timestamp_column="date", + value_columns=["value"], + ) + + # test round trip + json_str = mts.to_json() + mts2 = dm.TimeSeries.from_json(json_str) + assert json_str == mts2.to_json() + + +def test_ts_eq(): + """Test time series equivalence""" + df = pd.concat( + [ + pd.DataFrame( + [("a", x, x * 5) for x in range(20)], columns=["id", "ts", "val"] + ), + pd.DataFrame( + [("b", x, x * 5) for x in range(30)], columns=["id", "ts", "val"] + ), + ], + axis=0, + ) + + mts = dm.TimeSeries(df, key_column=["id"], timestamp_column="ts") + mts_a = dm.TimeSeries(df[df.id == "a"], key_column=["id"], timestamp_column="ts") + mts_missing_time = dm.TimeSeries( + df[df.ts < 20], key_column=["id"], timestamp_column="ts" + ) + + # null is equal + assert dm.TimeSeries(pd.DataFrame()) == dm.TimeSeries(pd.DataFrame()) + + # trivially equal + assert mts == mts + assert mts_a == mts_a + + # number of ids different + assert mts != mts_a + + # same number of ids, but different ones + df_c = df.copy() + df_c.loc[df_c["id"] == "b", "id"] = "c" + mts_c = dm.TimeSeries(df_c, key_column=["id"], timestamp_column="ts") + + assert mts != mts_c + + # missing time points + assert mts != mts_missing_time diff --git a/tests/interfaces/ts/data_model/test_timeseries_evaluation.py b/tests/interfaces/ts/data_model/test_timeseries_evaluation.py new file mode 100644 index 000000000..d379ce6b4 --- /dev/null +++ b/tests/interfaces/ts/data_model/test_timeseries_evaluation.py @@ -0,0 +1,260 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Tests for the Timeseries Evaluator data model object +""" +# Standard +import json +import warnings + +# Third Party +from pandas.testing import assert_frame_equal +import pandas as pd +import pytest + +# Local +from caikit.core.data_model import ProducerId +import caikit.interfaces.ts.data_model as dm + +warnings.filterwarnings("ignore", category=ResourceWarning) + + +@pytest.fixture(scope="module") +def eval_df_wo_offset(): + """Simple pandas df for testing target generation on multi-time series dataframe""" + + cv_pd_df_wo_offset = pd.DataFrame( + { + "prediction_mse": [0.013289, 0.010335, 1.099107], + "prediction_mae": [0.115280, 0.101663, 0.841607], + "prediction_smape": [0.010144, 0.006307, 0.014446], + } + ) + + yield cv_pd_df_wo_offset + + +@pytest.fixture(scope="module") +def eval_df_w_offset(): + cv_pd_df_w_offset = pd.DataFrame( + { + "ID1": ["A", "B", "C"], + "ID2": ["D", "E", "F"], + "prediction_mse": [0.013289, 0.010335, 1.099107], + "prediction_mae": [0.115280, 0.101663, 0.841607], + "prediction_smape": [0.010144, 0.006307, 0.014446], + "offset": ["overall", 0, 1], + } + ) + + yield cv_pd_df_w_offset + + +def test_Id_dm(eval_df_w_offset): + cv_pd_df_w_offset = eval_df_w_offset + + id_values = cv_pd_df_w_offset["offset"].tolist() + + id_value_dms = [dm.Id(id_value) for id_value in id_values] + + id_value_dm = dm.Id.from_proto(id_value_dms[0].to_proto()) + assert id_value_dm.text == id_values[0] + + id_value_dm = dm.Id.from_proto(id_value_dms[1].to_proto()) + assert id_value_dm.index == id_values[1] + + id_value_dm = dm.Id.from_json(id_value_dms[0].to_json()) + assert id_value_dm.text == id_values[0] + + id_value_dm = dm.Id.from_json(json.loads(id_value_dms[0].to_json())) + assert id_value_dm.text == id_values[0] + + id_value_dm = dm.Id.from_json(id_value_dms[2].to_json()) + assert id_value_dm.index == id_values[2] + + +def test_EvaluationRecord_dm(eval_df_wo_offset, eval_df_w_offset): + cv_pd_df_wo_offset, cv_pd_df_w_offset = eval_df_wo_offset, eval_df_w_offset + + sample_EvaluationRecords = cv_pd_df_wo_offset.T.values.T.tolist() + + ER_dm = dm.EvaluationRecord(metric_values=sample_EvaluationRecords[1]) + + assert len(ER_dm.id_values) == 0 + assert pytest.approx(ER_dm.metric_values[1]) == 0.101663 + assert ER_dm.offset == None + + ER_dm_rndTrip = dm.EvaluationRecord.from_proto(ER_dm.to_proto()) + assert len(ER_dm.id_values) == 0 + assert pytest.approx(ER_dm_rndTrip.metric_values[1]) == 0.101663 + assert ER_dm_rndTrip.offset == None + + ER_dm_rndTrip = dm.EvaluationRecord.from_json(ER_dm.to_json()) + assert len(ER_dm.id_values) == 0 + assert pytest.approx(ER_dm_rndTrip.metric_values[1]) == 0.101663 + assert ER_dm_rndTrip.offset == None + + sample_EvaluationRecords = cv_pd_df_w_offset.T.values.T.tolist() + + ER_dm = dm.EvaluationRecord( + sample_EvaluationRecords[0][:2], + sample_EvaluationRecords[0][2:5], + sample_EvaluationRecords[0][5], + ) + assert ER_dm.id_values[0].text == "A" + assert ER_dm.id_values[1].text == "D" + assert pytest.approx(ER_dm.metric_values[1]) == 0.11528 + assert ER_dm.offset.text == "overall" + + ER_dm_rndTrip = dm.EvaluationRecord.from_proto(ER_dm.to_proto()) + assert ER_dm_rndTrip.id_values[0].text == "A" + assert ER_dm_rndTrip.id_values[1].text == "D" + assert pytest.approx(ER_dm_rndTrip.metric_values[1]) == 0.11528 + assert ER_dm_rndTrip.offset.text == "overall" + + ER_dm_rndTrip = dm.EvaluationRecord.from_json(ER_dm.to_json()) + assert ER_dm_rndTrip.id_values[0].text == "A" + assert ER_dm_rndTrip.id_values[1].text == "D" + assert 
pytest.approx(ER_dm_rndTrip.metric_values[1]) == 0.11528 + assert ER_dm_rndTrip.offset.text == "overall" + + ER_dm = dm.EvaluationRecord( + sample_EvaluationRecords[1][:2], + sample_EvaluationRecords[1][2:5], + sample_EvaluationRecords[1][5], + ) + + assert ER_dm.id_values[0].text == "B" + assert ER_dm.id_values[1].text == "E" + assert pytest.approx(ER_dm.metric_values[1]) == 0.101663 + assert ER_dm.offset.index == 0 + + ER_dm_rndTrip = dm.EvaluationRecord.from_proto(ER_dm.to_proto()) + assert ER_dm_rndTrip.id_values[0].text == "B" + assert ER_dm_rndTrip.id_values[1].text == "E" + assert pytest.approx(ER_dm_rndTrip.metric_values[1]) == 0.101663 + assert ER_dm_rndTrip.offset.index == 0 + + ER_dm_rndTrip = dm.EvaluationRecord.from_json(ER_dm.to_json()) + assert ER_dm_rndTrip.id_values[0].text == "B" + assert ER_dm_rndTrip.id_values[1].text == "E" + assert pytest.approx(ER_dm_rndTrip.metric_values[1]) == 0.101663 + assert ER_dm_rndTrip.offset.index == 0 + + +def test_EvaluationResult_w_offset(eval_df_w_offset): + cv_pd_df_w_offset = eval_df_w_offset + + ER_dm = dm.EvaluationResult( + id_cols=["ID1", "ID2"], + metric_cols=["prediction_mse", "prediction_mae", "prediction_smape"], + offset_col="offset", + df=cv_pd_df_w_offset, + producer_id=ProducerId("Test", "1.0.0"), + ) + assert ER_dm.metric_cols == ["prediction_mse", "prediction_mae", "prediction_smape"] + assert ER_dm.id_cols == ["ID1", "ID2"] + assert ER_dm.producer_id.name == "Test" + assert ER_dm.producer_id.version == "1.0.0" + + pdf = ER_dm.as_pandas() + assert_frame_equal(pdf, cv_pd_df_w_offset) + assert (pdf.columns == cv_pd_df_w_offset.columns).all() + assert pdf.columns.tolist() == [x for x in cv_pd_df_w_offset.columns] + + ER_dm_to_proto = ER_dm.to_proto() + ER_dm_rndTrip = dm.EvaluationResult.from_proto(ER_dm_to_proto) + assert ER_dm_rndTrip.records[0].id_values[0].text == "A" + assert pytest.approx(ER_dm_rndTrip.records[0].metric_values[0]) == 0.013289 + assert ER_dm_rndTrip.id_cols[0] == "ID1" + + pdf = dm.EvaluationResult.from_proto(ER_dm.to_proto()).as_pandas() + assert_frame_equal(pdf, cv_pd_df_w_offset) + assert (pdf.columns == cv_pd_df_w_offset.columns).all() + assert pdf.columns.tolist() == [x for x in cv_pd_df_w_offset.columns] + + ER_dm_to_json = json.loads(ER_dm.to_json()) + assert len(ER_dm_to_json["id_cols"]) == 2 + assert len(ER_dm_to_json["metric_cols"]) == 3 + assert ER_dm_to_json["offset_col"] == "offset" + + ER_dm_rndTrip = dm.EvaluationResult.from_json(ER_dm_to_json) + assert ER_dm_rndTrip.records[0].id_values[0].text == "A" + assert pytest.approx(ER_dm_rndTrip.records[0].metric_values[0]) == 0.013289 + assert ER_dm_rndTrip.id_cols[0] == "ID1" + + pdf = dm.EvaluationResult.from_json(ER_dm.to_json()).as_pandas() + assert_frame_equal(pdf, cv_pd_df_w_offset) + assert (pdf.columns == cv_pd_df_w_offset.columns).all() + assert pdf.columns.tolist() == [x for x in cv_pd_df_w_offset.columns] + + +def test_EvaluationResult_wo_offset(eval_df_wo_offset): + cv_pd_df_wo_offset = eval_df_wo_offset + + ER_dm = dm.EvaluationResult( + metric_cols=["prediction_mse", "prediction_mae", "prediction_smape"], + df=cv_pd_df_wo_offset, + ) + + assert ER_dm.metric_cols == ["prediction_mse", "prediction_mae", "prediction_smape"] + assert len(ER_dm.id_cols) == 0 + + pdf = ER_dm.as_pandas() + assert_frame_equal(pdf, cv_pd_df_wo_offset) + assert (pdf.columns == cv_pd_df_wo_offset.columns).all() + assert pdf.columns.tolist() == [x for x in cv_pd_df_wo_offset.columns] + + ER_dm_to_proto = ER_dm.to_proto() + ER_dm_rndTrip = 
dm.EvaluationResult.from_proto(ER_dm_to_proto) + assert len(ER_dm_rndTrip.records[0].id_values) == 0 + assert pytest.approx(ER_dm_rndTrip.records[0].metric_values[0]) == 0.013289 + assert len(ER_dm_rndTrip.id_cols) == 0 + + pdf = dm.EvaluationResult.from_proto(ER_dm.to_proto()).as_pandas() + assert_frame_equal(pdf, cv_pd_df_wo_offset) + assert (pdf.columns == cv_pd_df_wo_offset.columns).all() + assert pdf.columns.tolist() == [x for x in cv_pd_df_wo_offset.columns] + + ER_dm_to_json = json.loads(ER_dm.to_json()) + assert len(ER_dm_to_json["id_cols"]) == 0 + assert len(ER_dm_to_json["metric_cols"]) == 3 + assert ER_dm_to_json["offset_col"] == None + + ER_dm_rndTrip = dm.EvaluationResult.from_json(ER_dm_to_json) + assert len(ER_dm_rndTrip.records[0].id_values) == 0 + assert pytest.approx(ER_dm_rndTrip.records[0].metric_values[0]) == 0.013289 + assert len(ER_dm_rndTrip.id_cols) == 0 + + pdf = dm.EvaluationResult.from_json(ER_dm.to_json()).as_pandas() + assert_frame_equal(pdf, cv_pd_df_wo_offset) + assert (pdf.columns == cv_pd_df_wo_offset.columns).all() + assert pdf.columns.tolist() == [x for x in cv_pd_df_wo_offset.columns] + + +def test_Errors(eval_df_w_offset): + cv_pd_df_w_offset = eval_df_w_offset + + id_values = cv_pd_df_w_offset["offset"].tolist() + # id_value_dms = [dm.Id(id_value) for id_value in id_values] + + sample_EvaluationRecords = cv_pd_df_w_offset.T.values.T.tolist() + ER_dm = dm.EvaluationRecord( + sample_EvaluationRecords[0][:2], + sample_EvaluationRecords[0][2:5], + sample_EvaluationRecords[0][5], + ) + + with pytest.raises(ValueError): + dm.Id.from_proto(ER_dm.to_proto()) diff --git a/tests/interfaces/ts/data_model/util.py b/tests/interfaces/ts/data_model/util.py new file mode 100644 index 000000000..de5ce90d3 --- /dev/null +++ b/tests/interfaces/ts/data_model/util.py @@ -0,0 +1,96 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities used in data_model tests""" + +# Standard +from typing import Iterable, List, Tuple +import copy +import os + +# Third Party +import numpy as np +import pandas as pd +import pyspark + +# Local +from caikit.interfaces.ts.data_model.backends.util import strip_periodic +from tests.interfaces.ts.helpers import sslocal + + +def _create_spark_dataframes( + pandas_dfs: Iterable[Tuple], +) -> Iterable[Tuple]: + """Creates spark dataframe versions of given native pandas dataframes.""" + + spark_session: pyspark.sql.SparkSession = sslocal() + answer = [] + for tup in pandas_dfs: + dftouse = strip_periodic(tup[0], ts_col_name=tup[1]) + toappend = ( + (spark_session.createDataFrame(dftouse), tup[1], tup[2], tup[3]) + if len(tup) > 2 + else (spark_session.createDataFrame(dftouse), tup[1]) + ) + answer.append(toappend) + return answer + + +def create_extended_test_dfs(testable_pandas_data_frames: List[Tuple]) -> List[Tuple]: + """Extend (or not) the input list of native pandas dataframes with their spark dataframe equivalents.
+ Allow picking and choosing via an environment variable setting for DFTYPE.""" + answer = copy.copy(testable_pandas_data_frames) + DFTYPE = os.getenv("DFTYPE", None) + try: + # only create the spark dataframes if needed + if DFTYPE is None or DFTYPE == "spark_all": + testable_spark_dataframes = _create_spark_dataframes( + testable_pandas_data_frames + ) + + if DFTYPE is None: + answer.extend(testable_spark_dataframes) + elif DFTYPE == "pandas_all": + ... # no op + elif DFTYPE == "spark_all": + answer = testable_spark_dataframes + # elif DFTYPE.startswith("pandas_"): + # answer = [answer[int(DFTYPE.split("_")[1])]] + # elif DFTYPE.startswith("spark_"): + # answer = [testable_spark_dataframes[int(DFTYPE.split("_")[1])]] + else: + raise Exception(f"invalid setting {DFTYPE} for DFTYPE") + return answer + except IndexError as ie: + print(ie) + return testable_pandas_data_frames + testable_spark_dataframes + + +def key_helper(df, baskey): + """This might not be necessary any more. Its intent was to enforce string + column names for spark dataframes when we were allowing integer column names + for the pandas backend implementation. It's kept for legacy purposes in the test + for the time being.""" + return baskey if baskey is None or isinstance(df, pd.DataFrame) else str(baskey) + + +def df_project(df): + """Return pandas api on the fly when needed by tests.""" + return df.pandas_api() if isinstance(df, pyspark.sql.DataFrame) else df + + +def df_col_to_nparray(df, col): + if isinstance(df, pyspark.sql.DataFrame): + return np.array([row[col] for row in df.collect()]) + + return np.array(df[col].to_numpy()) diff --git a/tests/interfaces/ts/helpers.py b/tests/interfaces/ts/helpers.py new file mode 100644 index 000000000..849582a26 --- /dev/null +++ b/tests/interfaces/ts/helpers.py @@ -0,0 +1,126 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Common test helpers +""" + +# Standard +from typing import Union +import warnings + +# Third Party +from pyspark.sql import Row, SparkSession +import numpy as np +import pandas as pd +import pyspark +import pytest + +# First Party +import alog + +# Local +from caikit.interfaces.ts.data_model.toolkit import optional_dependencies +from caikit.interfaces.ts.data_model.toolkit.sparkconf import sparkconf_local +import caikit.interfaces.ts.data_model as dm + +warnings.filterwarnings("ignore", category=ResourceWarning) + +## Global Config ############################################################### + +test_log = alog.use_channel("TEST") + +## Test Data ################################################################### + +sample_data = {"key": [1, 2, 3], "val": [4, 5, 6], "val2": [7.1, 8.1, 9.1]} +sample_df = pd.DataFrame(sample_data) +sample_np = np.array( + [sample_data["key"], sample_data["val"], sample_data["val2"]] +).transpose() +sample_np_univariate = sample_np[:, [0, 1]] +sample_ts = list(zip(sample_data["key"], sample_data["val"])) +sample_mvts = list( + zip( + sample_data["key"], + ((v1, v2) for v1, v2 in zip(sample_data["val"], sample_data["val2"])), + ) +) + + +## Helpers ##################################################################### + + +@pytest.fixture(scope="session") +def sslocal_fixture(): + spark_session = SparkSession.builder.config(conf=sparkconf_local()).getOrCreate() + yield spark_session + spark_session.stop() + + +def sslocal(): + return SparkSession.builder.config(conf=sparkconf_local()).getOrCreate() + + +@pytest.fixture(scope="session") +def sample_spark_df(sslocal_fixture): + """Pytest fixture for a self-enclosed spark data frame""" + spark = sslocal_fixture + + sample_spark_df_ = spark.createDataFrame( + [ + Row(**{key: val[idx] for key, val in sample_data.items()}) + for idx in range(sample_data["key"]) + ] + ) + yield sample_spark_df_ + + +@pytest.fixture(scope="session") +def sample_spark_df_univariate(sample_spark_df): + """Pytest fixture for a self-enclosed spark data frame""" + return sample_spark_df.select(["key", "val"]) + + +@pytest.fixture +def no_pandas(): + """Fixture to simulate running without pandas installed""" + current = optional_dependencies.HAVE_PANDAS + HAVE_PANDAS = False + yield + HAVE_PANDAS = current + + +@pytest.fixture +def no_spark(): + """Fixture to simulate running without pyspark installed""" + current = optional_dependencies.HAVE_PYSPARK + optional_dependencies.HAVE_PYSPARK = False + yield + optional_dependencies.HAVE_PYSPARK = current + + +## other helpers + + +def get_anytimeseries_length( + X: Union[pd.DataFrame, dm.TimeSeries, pyspark.sql.DataFrame] +) -> Union[None, int]: + """Get the the length of any AnyTimeSeries object""" + if isinstance(X, pd.DataFrame): + return len(X) + elif isinstance(X, dm.TimeSeries): + return len(X) + elif isinstance(X, pyspark.sql.DataFrame): + return X.count() + else: + raise ValueError("Unknown time series type provided") diff --git a/tests/runtime/conftest.py b/tests/runtime/conftest.py index 2b014e091..f70947f14 100644 --- a/tests/runtime/conftest.py +++ b/tests/runtime/conftest.py @@ -14,6 +14,7 @@ import tempfile import threading import time +import warnings # Third Party from grpc_health.v1 import health_pb2, health_pb2_grpc @@ -35,6 +36,8 @@ from caikit.runtime.service_generation.rpcs import TaskPredictRPC from caikit.runtime.servicers.global_predict_servicer import GlobalPredictServicer from caikit.runtime.servicers.global_train_servicer import GlobalTrainServicer +from 
caikit.runtime.servicers.model_runtime_servicer import ModelRuntimeServicerImpl +from caikit.runtime.work_management.abortable_context import ThreadInterrupter from tests.conftest import random_test_id, temp_config from tests.fixtures import Fixtures @@ -68,10 +71,9 @@ def http_session_scoped_open_port(): return _open_port() -def _open_port(): +def _open_port(start=8888): # TODO: This has obvious problems where the port returned for use by a test is not immediately # put into use, so parallel tests could attempt to use the same port. - start = 8888 end = start + 1000 host = "localhost" for port in range(start, end): @@ -98,13 +100,18 @@ def sample_inference_service(render_protos) -> ServicePackage: @pytest.fixture(scope="session") def sample_predict_servicer(sample_inference_service) -> GlobalPredictServicer: - servicer = GlobalPredictServicer(inference_service=sample_inference_service) + interrupter = ThreadInterrupter() + interrupter.start() + servicer = GlobalPredictServicer( + inference_service=sample_inference_service, interrupter=interrupter + ) yield servicer # Make sure to not leave the rpc_meter hanging # (It does try to clean itself up on destruction, but just to be sure) rpc_meter = getattr(servicer, "rpc_meter", None) if rpc_meter: rpc_meter.end_writer_thread() + interrupter.stop() @pytest.fixture(scope="session") @@ -151,9 +158,7 @@ def runtime_grpc_test_server(open_port, *args, **kwargs): @pytest.fixture(scope="session") -def runtime_grpc_server( - session_scoped_open_port, -) -> RuntimeGRPCServer: +def runtime_grpc_server(session_scoped_open_port) -> RuntimeGRPCServer: with runtime_grpc_test_server( session_scoped_open_port, ) as server: @@ -161,6 +166,12 @@ def runtime_grpc_server( yield server +@pytest.fixture(scope="session") +def model_runtime_servicer(runtime_grpc_server) -> ModelRuntimeServicerImpl: + # Builds a new servicer, the one in the server is a bit hard to access + return ModelRuntimeServicerImpl(interrupter=runtime_grpc_server.interrupter) + + @contextmanager def runtime_http_test_server(open_port, *args, **kwargs): """Helper to wrap creation of RuntimeHTTPServer in temporary configurations""" @@ -179,12 +190,18 @@ def runtime_http_test_server(open_port, *args, **kwargs): }, "merge", ): - config_overrides = {} - if "tls_config_override" in kwargs: - config_overrides = kwargs["tls_config_override"] - kwargs["tls_config_override"] = config_overrides["runtime"]["tls"] + # Forward the special "tls_config_override" to "tls_config_override" + # IFF the configs contain actual TLS (indicated by the presence of + # the special "use_in_test" element). 
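+ # e.g. an override shaped like
+ # {"runtime": {"tls": <tls-section>}, "use_in_test": {...}}
+ # becomes kwargs["tls_config_override"] = <tls-section>; overrides
+ # without a runtime.tls section are dropped entirely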
+ config_overrides = kwargs.pop("tls_config_override", {}) + if tls_config_override := config_overrides.get("runtime", {}).get("tls"): + kwargs["tls_config_override"] = tls_config_override + else: + config_overrides = {} + check_readiness = kwargs.pop("check_readiness", True) with http_server.RuntimeHTTPServer(*args, **kwargs) as server: - _check_http_server_readiness(server, config_overrides) + if check_readiness: + _check_http_server_readiness(server, config_overrides) # Give tests access to the workdir server.workdir = workdir yield server @@ -267,6 +284,23 @@ def file_task_model_id(box_model_path) -> str: model_manager.unload_model(model_id) +@pytest.fixture +def primitive_task_model_id(primitive_model_path) -> str: + """Loaded model ID using model manager load model implementation""" + model_id = random_test_id() + model_manager = ModelManager.get_instance() + # model load test already tests with archive - just using a model path here + model_manager.load_model( + model_id, + local_model_path=primitive_model_path, + model_type=Fixtures.get_good_model_type(), # eventually we'd like to be determining the type from the model itself... + ) + yield model_id + + # teardown + model_manager.unload_model(model_id) + + @pytest.fixture def sample_task_unary_rpc(sample_inference_service: ServicePackage) -> TaskPredictRPC: return sample_inference_service.caikit_rpcs["SampleTaskPredict"] @@ -383,7 +417,9 @@ def _kill_proc(self): self.proc.kill() def __enter__(self): - self.proc = subprocess.Popen(self._cmd, env=self._env) + self.proc = subprocess.Popen( + self._cmd, env=self._env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) self._kill_timer.start() return self.proc @@ -417,12 +453,10 @@ def _check_server_readiness(server): def _check_http_server_readiness(server, config_overrides: Dict[str, Dict]): mode = "http" - verify = None cert = None # tls if config_overrides: mode = "https" - verify = config_overrides["use_in_test"]["ca_cert"] # mtls if "client_cert" and "client_key" in config_overrides["use_in_test"]: cert = ( @@ -432,11 +466,13 @@ def _check_http_server_readiness(server, config_overrides: Dict[str, Dict]): done = False while not done: try: - response = requests.get( - f"{mode}://localhost:{server.port}{http_server.HEALTH_ENDPOINT}", - verify=verify, - cert=cert, - ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", module="urllib3") + response = requests.get( + f"{mode}://localhost:{server.port}{http_server.HEALTH_ENDPOINT}", + verify=False, + cert=cert, + ) assert response.status_code == 200 assert response.text == "OK" done = True diff --git a/tests/runtime/http_server/test_http_server.py b/tests/runtime/http_server/test_http_server.py index 4ec36865c..27d53e11e 100644 --- a/tests/runtime/http_server/test_http_server.py +++ b/tests/runtime/http_server/test_http_server.py @@ -16,10 +16,10 @@ """ # Standard from contextlib import contextmanager -from dataclasses import dataclass from io import BytesIO from pathlib import Path -from typing import Dict +from typing import Dict, List, Optional +import base64 import json import os import signal @@ -28,20 +28,34 @@ # Third Party from fastapi.testclient import TestClient -import numpy as np import pytest import requests import tls_test_tools # Local from caikit.core import MODEL_MANAGER, DataObjectBase, dataobject +from caikit.core.data_model import TrainingStatus +from caikit.core.model_management.multi_model_finder import MultiModelFinder from caikit.runtime import http_server +from 
caikit.runtime.http_server.http_server import StreamEventTypes from tests.conftest import temp_config from tests.runtime.conftest import ( ModuleSubproc, register_trained_model, runtime_http_test_server, ) +from tests.runtime.model_management.test_model_manager import ( + non_singleton_model_managers, +) + +## Fixtures ##################################################################### + + +@pytest.fixture +def client(runtime_http_server) -> TestClient: + with TestClient(runtime_http_server.app) as client: + yield client + ## Helpers ##################################################################### @@ -77,19 +91,24 @@ def generate_tls_configs( tls: bool = False, mtls: bool = False, inline: bool = False, + separate_client_ca: bool = False, + server_sans: Optional[List[str]] = None, + client_sans: Optional[List[str]] = None, **http_config_overrides, ) -> Dict[str, Dict]: """Helper to generate tls configs""" with tempfile.TemporaryDirectory() as workdir: config_overrides = {} - client_keyfile, client_certfile, ca_certfile = None, None, None + client_keyfile, client_certfile = None, None ca_cert, server_cert, server_key = None, None, None + use_in_test = config_overrides.setdefault("use_in_test", {}) + use_in_test["workdir"] = workdir if mtls or tls: ca_key = tls_test_tools.generate_key()[0] ca_cert = tls_test_tools.generate_ca_cert(ca_key) - ca_certfile, _ = save_key_cert_pair("ca", workdir, cert=ca_cert) server_key, server_cert = tls_test_tools.generate_derived_key_cert_pair( - ca_key=ca_key + ca_key=ca_key, + san_list=server_sans, ) server_certfile, server_keyfile = save_key_cert_pair( "server", workdir, server_key, server_cert @@ -98,15 +117,20 @@ def generate_tls_configs( if inline: tls_config = TLSConfig( server=KeyPair(cert=server_cert, key=server_key), - client=KeyPair(cert=ca_cert if mtls else "", key=""), + client=KeyPair(cert="", key=""), ) else: tls_config = TLSConfig( server=KeyPair(cert=server_certfile, key=server_keyfile), - client=KeyPair(cert=ca_certfile if mtls else "", key=""), + client=KeyPair(cert="", key=""), ) - # need to save this ca_certfile in config_overrides so the tls tests below can access it from client side - config_overrides["use_in_test"] = {"ca_cert": ca_certfile} + + # need to save this ca_certfile in config_overrides so the tls + # tests below can access it from client side + ca_certfile, _ = save_key_cert_pair("ca", workdir, cert=ca_cert) + use_in_test["ca_cert"] = ca_certfile + use_in_test["server_key"] = server_keyfile + use_in_test["server_cert"] = server_certfile # also saving a bad ca_certfile for a failure test case bad_ca_file = os.path.join(workdir, "bad_ca_cert.crt") @@ -115,17 +139,42 @@ def generate_tls_configs( "-----BEGIN CERTIFICATE-----\nfoobar\n-----END CERTIFICATE-----" ) handle.write(bad_cert) - config_overrides["use_in_test"]["bad_ca_cert"] = bad_ca_file + use_in_test["bad_ca_cert"] = bad_ca_file if mtls: + if separate_client_ca: + subject_kwargs = {"common_name": "my.client"} + client_ca_key = tls_test_tools.generate_key()[0] + client_ca_cert = tls_test_tools.generate_ca_cert( + client_ca_key, **subject_kwargs + ) + else: + subject_kwargs = {} + client_ca_key = ca_key + client_ca_cert = ca_cert + + # If inlining the client CA + if inline: + tls_config.client.cert = client_ca_cert + else: + client_ca_certfile, _ = save_key_cert_pair( + "client_ca", workdir, cert=client_ca_cert + ) + tls_config.client.cert = client_ca_certfile + + # Set up the client key/cert pair derived from the client CA client_certfile, client_keyfile = 
save_key_cert_pair( "client", workdir, - *tls_test_tools.generate_derived_key_cert_pair(ca_key=ca_key), + *tls_test_tools.generate_derived_key_cert_pair( + ca_key=client_ca_key, + san_list=client_sans, + **subject_kwargs, + ), ) # need to save the client cert and key in config_overrides so the mtls test below can access it - config_overrides["use_in_test"]["client_cert"] = client_certfile - config_overrides["use_in_test"]["client_key"] = client_keyfile + use_in_test["client_cert"] = client_certfile + use_in_test["client_key"] = client_keyfile config_overrides["runtime"] = {"tls": tls_config.to_dict()} config_overrides.setdefault("runtime", {})["http"] = { @@ -149,9 +198,7 @@ def test_insecure_server(runtime_http_server, open_port): def test_basic_tls_server(open_port): - with generate_tls_configs( - open_port, tls=True, mtls=False, http_config_overrides={} - ) as config_overrides: + with generate_tls_configs(open_port, tls=True, mtls=False) as config_overrides: with runtime_http_test_server( open_port, tls_config_override=config_overrides, @@ -165,9 +212,7 @@ def test_basic_tls_server(open_port): def test_basic_tls_server_with_wrong_cert(open_port): - with generate_tls_configs( - open_port, tls=True, mtls=False, http_config_overrides={} - ) as config_overrides: + with generate_tls_configs(open_port, tls=True, mtls=False) as config_overrides: with runtime_http_test_server( open_port, tls_config_override=config_overrides, @@ -181,14 +226,36 @@ def test_basic_tls_server_with_wrong_cert(open_port): def test_mutual_tls_server(open_port): + with generate_tls_configs(open_port, tls=True, mtls=True) as config_overrides: + with runtime_http_test_server( + open_port, + tls_config_override=config_overrides, + ) as http_server_with_mtls: + # start a non-blocking http server with mutual tls + resp = requests.get( + f"https://localhost:{http_server_with_mtls.port}/docs", + verify=config_overrides["use_in_test"]["ca_cert"], + cert=( + config_overrides["use_in_test"]["client_cert"], + config_overrides["use_in_test"]["client_key"], + ), + ) + resp.raise_for_status() + + +def test_mutual_tls_server_different_client_ca(open_port): with generate_tls_configs( - open_port, tls=True, mtls=True, http_config_overrides={} + open_port, + tls=True, + mtls=True, + separate_client_ca=True, ) as config_overrides: + # start a non-blocking http server with mutual tls with runtime_http_test_server( open_port, tls_config_override=config_overrides, ) as http_server_with_mtls: - # start a non-blocking http server with mutual tls + # Make a request with the client's key/cert pair resp = requests.get( f"https://localhost:{http_server_with_mtls.port}/docs", verify=config_overrides["use_in_test"]["ca_cert"], @@ -205,7 +272,7 @@ def test_mutual_tls_server_inline(open_port): than with files """ with generate_tls_configs( - open_port, tls=True, mtls=True, inline=True, http_config_overrides={} + open_port, tls=True, mtls=True, inline=True ) as config_overrides: with runtime_http_test_server( open_port, @@ -224,9 +291,7 @@ def test_mutual_tls_server_inline(open_port): def test_mutual_tls_server_with_wrong_cert(open_port): - with generate_tls_configs( - open_port, tls=True, mtls=True, http_config_overrides={} - ) as config_overrides: + with generate_tls_configs(open_port, tls=True, mtls=True) as config_overrides: with runtime_http_test_server( open_port, tls_config_override=config_overrides, @@ -280,11 +345,10 @@ def test_services_disabled(open_port, enabled_services): ## Inference Tests 
####################################################################### -def test_docs(runtime_http_server): +def test_docs(client): """Simple check that pinging /docs returns 200""" - with TestClient(runtime_http_server.app) as client: - response = client.get("/docs") - assert response.status_code == 200 + response = client.get("/docs") + assert response.status_code == 200 def test_docs_using_running_http_server(runtime_http_server): @@ -294,150 +358,198 @@ def test_docs_using_running_http_server(runtime_http_server): assert response.status_code == 200 -def test_inference_sample_task(sample_task_model_id, runtime_http_server): - """Simple check that we can ping a model""" - with TestClient(runtime_http_server.app) as client: - json_input = {"inputs": {"name": "world"}, "model_id": sample_task_model_id} - response = client.post( - f"/api/v1/task/sample", - json=json_input, - ) - json_response = json.loads(response.content.decode(response.default_encoding)) - print(json_response) - assert response.status_code == 200, json_response - assert json_response["greeting"] == "Hello world" +def test_docs_with_models( + runtime_http_server, sample_task_model_id, primitive_task_model_id +): + """Simple check that pinging /docs still returns 200 when models have been + loaded""" + response = requests.get(f"http://localhost:{runtime_http_server.port}/docs") + assert response.status_code == 200 -def test_inference_sample_task_optional_field( - sample_task_model_id, runtime_http_server -): - """Simple check for optional fields""" - with TestClient(runtime_http_server.app) as client: - json_input = { - "model_id": sample_task_model_id, - "inputs": {"name": "world"}, - "parameters": {"throw": True}, - } - response = client.post( - f"/api/v1/task/sample", - json=json_input, - ) - # this is 500 because we explicitly pass in `throw` as True, which - # raises an internal error in the module - assert response.status_code == 500 +def test_inference_sample_task(sample_task_model_id, client): + """Simple check that we can ping a model""" + json_input = {"inputs": {"name": "world"}, "model_id": sample_task_model_id} + response = client.post( + f"/api/v1/task/sample", + json=json_input, + ) + json_response = json.loads(response.content.decode(response.default_encoding)) + assert response.status_code == 200, json_response + assert json_response["greeting"] == "Hello world" -def test_inference_sample_task_multipart_input( - sample_task_model_id, runtime_http_server -): +def test_inference_primitive_task(primitive_task_model_id, client): + """Simple check that we can ping a model""" + json_input = { + "inputs": {"name": "hello"}, + "parameters": { + "bool_type": True, + "int_type": 1, + "float_type": 1.0, + "str_type": "astring", + "bytes_type": "cmF3Ynl0ZXMK", + "list_type": ["list", "of", "strings"], + }, + "model_id": primitive_task_model_id, + } + response = client.post( + f"/api/v1/task/sample", + json=json_input, + ) + json_response = json.loads(response.content.decode(response.default_encoding)) + assert response.status_code == 200, json_response + assert "hello: primitives!" 
in json_response["greeting"] + + +def test_inference_sample_task_optional_field(sample_task_model_id, client): + """Simple check for optional fields""" + json_input = { + "model_id": sample_task_model_id, + "inputs": {"name": "world"}, + "parameters": {"throw": True}, + } + response = client.post( + f"/api/v1/task/sample", + json=json_input, + ) + # this is 500 because we explicitly pass in `throw` as True, which + # raises an internal error in the module + assert response.status_code == 500 + + +def test_inference_sample_task_multipart_input(sample_task_model_id, client): """Simple check that we can submit multipart requests""" - with TestClient(runtime_http_server.app) as client: - multipart_input = { - "model_id": sample_task_model_id, - "inputs.name": "world", - "parameters": json.dumps({"throw": False}), - } + multipart_input = { + "model_id": sample_task_model_id, + "inputs.name": "world", + "parameters": json.dumps({"throw": False}), + } - response = client.post(f"/api/v1/task/sample", files=multipart_input) + response = client.post(f"/api/v1/task/sample", files=multipart_input) - json_response = json.loads(response.content.decode(response.default_encoding)) - assert response.status_code == 200, json_response - assert json_response["greeting"] == "Hello world" + json_response = json.loads(response.content.decode(response.default_encoding)) + assert response.status_code == 200, json_response + assert json_response["greeting"] == "Hello world" - multipart_input["parameters"] = json.dumps({"throw": True}) - response = client.post( - f"/api/v1/task/sample", - files=multipart_input, - ) - # this is 500 because we explicitly pass in `throw` as True, which - # raises an internal error in the module - assert response.status_code == 500 + multipart_input["parameters"] = json.dumps({"throw": True}) + response = client.post( + f"/api/v1/task/sample", + files=multipart_input, + ) + # this is 500 because we explicitly pass in `throw` as True, which + # raises an internal error in the module + assert response.status_code == 500 -def test_inference_file_task_multipart_flipped_input( - file_task_model_id, runtime_http_server -): +def test_inference_file_task_multipart_flipped_input(file_task_model_id, client): """Ensure that multiple multipart json inputs are merged together instead of overriding""" - with TestClient(runtime_http_server.app) as client: - # cGRmZGF0Yf//AA== is b"pdfdata\xff\xff\x00" base64 encoded - temp_file = tempfile.NamedTemporaryFile() - temp_file_name = Path(temp_file.name).name - temp_file.write(b"pdfdata\xff\xff\x00") - temp_file.flush() - temp_file.seek(0) - - file_input = { - "model_id": file_task_model_id, - "inputs.file": temp_file, - "inputs": json.dumps({"metadata": {"name": "agoodname"}}), - } + # cGRmZGF0Yf//AA== is b"pdfdata\xff\xff\x00" base64 encoded + temp_file = tempfile.NamedTemporaryFile() + temp_file_name = Path(temp_file.name).name + temp_file.write(b"pdfdata\xff\xff\x00") + temp_file.flush() + temp_file.seek(0) + + file_input = { + "model_id": file_task_model_id, + "inputs.file": temp_file, + "inputs": json.dumps({"metadata": {"name": "agoodname"}}), + } + + response = client.post( + f"/api/v1/task/file", + files=file_input, + ) + content_stream = BytesIO(response.content) - response = client.post( - f"/api/v1/task/file", - files=file_input, - ) - content_stream = BytesIO(response.content) - - assert response.status_code == 200 - with zipfile.ZipFile(content_stream) as output_file: - assert len(output_file.namelist()) == 2 - assert "metadata.json" in 
output_file.namelist() - assert f"processed_{temp_file_name}" in output_file.namelist() + assert response.status_code == 200 + with zipfile.ZipFile(content_stream) as output_file: + assert len(output_file.namelist()) == 2 + assert "metadata.json" in output_file.namelist() + assert f"processed_{temp_file_name}" in output_file.namelist() - with output_file.open(f"processed_{temp_file_name}") as pdf_result: - assert pdf_result.read() == b"bounding|pdfdata\xff\xff\x00|box" + with output_file.open(f"processed_{temp_file_name}") as pdf_result: + assert pdf_result.read() == b"bounding|pdfdata\xff\xff\x00|box" -def test_inference_other_task(other_task_model_id, runtime_http_server): +def test_inference_other_task(other_task_model_id, client): """Simple check that we can ping a model""" - with TestClient(runtime_http_server.app) as client: - json_input = {"model_id": other_task_model_id, "inputs": {"name": "world"}} - response = client.post( - f"/api/v1/task/other", - json=json_input, - ) - json_response = json.loads(response.content.decode(response.default_encoding)) - assert response.status_code == 200, json_response - assert json_response["farewell"] == "goodbye: world 42 times" + json_input = {"model_id": other_task_model_id, "inputs": {"name": "world"}} + response = client.post( + f"/api/v1/task/other", + json=json_input, + ) + json_response = json.loads(response.content.decode(response.default_encoding)) + assert response.status_code == 200, json_response + assert json_response["farewell"] == "goodbye: world 42 times" -def test_output_file_task(file_task_model_id, runtime_http_server): +def test_output_file_task(file_task_model_id, client): """Simple check that we can get a file output""" - with TestClient(runtime_http_server.app) as client: - # cGRmZGF0Yf//AA== is b"pdfdata\xff\xff\x00" base64 encoded - temp_file = tempfile.NamedTemporaryFile() - temp_file_name = Path(temp_file.name).name - temp_file.write(b"pdfdata\xff\xff\x00") - temp_file.flush() - temp_file.seek(0) - - file_input = { - "model_id": file_task_model_id, - "inputs.file": temp_file, - "inputs.metadata": json.dumps({"name": "agoodname"}), - } + # cGRmZGF0Yf//AA== is b"pdfdata\xff\xff\x00" base64 encoded + temp_file = tempfile.NamedTemporaryFile() + temp_file_name = Path(temp_file.name).name + temp_file.write(b"pdfdata\xff\xff\x00") + temp_file.flush() + temp_file.seek(0) + + file_input = { + "model_id": file_task_model_id, + "inputs.file": temp_file, + "inputs.metadata": json.dumps({"name": "agoodname"}), + } + + response = client.post( + f"/api/v1/task/file", + files=file_input, + ) + content_stream = BytesIO(response.content) - response = client.post( - f"/api/v1/task/file", - files=file_input, - ) - content_stream = BytesIO(response.content) - - assert response.status_code == 200 - with zipfile.ZipFile(content_stream) as output_file: - assert len(output_file.namelist()) == 2 - assert "metadata.json" in output_file.namelist() - assert f"processed_{temp_file_name}" in output_file.namelist() + assert response.status_code == 200 + with zipfile.ZipFile(content_stream) as output_file: + assert len(output_file.namelist()) == 2 + assert "metadata.json" in output_file.namelist() + assert f"processed_{temp_file_name}" in output_file.namelist() + + with output_file.open(f"processed_{temp_file_name}") as pdf_result: + assert pdf_result.read() == b"bounding|pdfdata\xff\xff\x00|box" + + +def test_invalid_input_exception(file_task_model_id, client): + """Simple check that the server catches caikit core exceptions""" + json_file_input = { 
+ "model_id": file_task_model_id, + "inputs": { + "file": { + # cGRmZGF0Yf//AA== is b"pdfdata\xff\xff\x00" base64 encoded + "data": "cGRmZGF0Yf//AA==", + "filename": "unsupported_file.exe", + }, + "metadata": { + "name": "agoodname", + }, + }, + } - with output_file.open(f"processed_{temp_file_name}") as pdf_result: - assert pdf_result.read() == b"bounding|pdfdata\xff\xff\x00|box" + response = client.post( + f"/api/v1/task/file", + json=json_file_input, + ) + assert response.status_code == 400 + json_response = json.loads(response.content.decode(response.default_encoding)) + assert json_response["details"] == "Executables are not a supported File type" -def test_inference_streaming_sample_module(sample_task_model_id, runtime_http_server): +@pytest.mark.skip( + "Skipping testing streaming cases with FastAPI's testclient, pending resolution https://github.com/tiangolo/fastapi/discussions/10518" +) +def test_inference_streaming_sample_module(sample_task_model_id, client): """Simple check for testing a happy path unary-stream case""" - with TestClient(runtime_http_server.app) as client: - json_input = {"model_id": sample_task_model_id, "inputs": {"name": "world"}} + json_input = {"model_id": sample_task_model_id, "inputs": {"name": "world"}} + # send in multiple requests just to check + for i in range(10): stream = client.post( f"/api/v1/task/server-streaming-sample", json=json_input, @@ -446,102 +558,354 @@ def test_inference_streaming_sample_module(sample_task_model_id, runtime_http_se stream_content = stream.content.decode(stream.default_encoding) stream_responses = json.loads( "[{}]".format( - stream_content.replace("data: ", "") + stream_content.replace("event: ", '{"event":') + .replace( + StreamEventTypes.MESSAGE.value, + '"' + f"{StreamEventTypes.MESSAGE.value}" + '"}', + ) + .replace("data: ", "") .replace("\r\n", "") .replace("}{", "}, {") ) ) - assert len(stream_responses) == 10 + assert len(stream_responses) == 20 assert all( - resp.get("greeting") == "Hello world stream" for resp in stream_responses + resp.get("greeting") == "Hello world stream" + for resp in stream_responses + if "greeting" in resp ) - - -def test_no_model_id(runtime_http_server): - """Simple check that we can ping a model""" - with TestClient(runtime_http_server.app) as client: - response = client.post( - f"/api/v1/task/sample", - json={"inputs": {"name": "world"}}, - ) - assert response.status_code == 400 - "Please provide model_id in payload" in response.content.decode( - response.default_encoding + assert all( + resp.get("event") == StreamEventTypes.MESSAGE.value + for resp in stream_responses + if "event" in resp ) -def test_inference_multi_task_module(multi_task_model_id, runtime_http_server): - """Simple check that we can ping a model""" - with TestClient(runtime_http_server.app) as client: - # cGRmZGF0Yf//AA== is b"pdfdata\xff\xff\x00" base64 encoded - json_input = { - "model_id": multi_task_model_id, - "inputs": {"filename": "example.pdf", "data": "cGRmZGF0Yf//AA=="}, - } - response = client.post( - f"/api/v1/task/second", - json=json_input, - ) - json_response = json.loads(response.content.decode(response.default_encoding)) - assert response.status_code == 200, json_response - assert json_response["farewell"] == "Goodbye from SecondTask" - +def test_inference_streaming_sample_module_actual_server( + sample_task_model_id, runtime_http_server +): + """Simple check for testing a happy path unary-stream case + but pings the actual running server""" -def test_model_not_found(runtime_http_server): - """Simple 
check that we can ping a model""" - with TestClient(runtime_http_server.app) as client: - response = client.post( - f"/api/v1/task/sample", - json={"model_id": "not_an_id", "inputs": {"name": "world"}}, + for i in range(10): + input = {"model_id": sample_task_model_id, "inputs": {"name": f"world{i}"}} + url = f"http://localhost:{runtime_http_server.port}/api/v1/task/server-streaming-sample" + stream = requests.post(url=url, json=input, verify=False) + assert stream.status_code == 200 + stream_content = stream.content.decode(stream.encoding) + stream_responses = json.loads( + "[{}]".format( + stream_content.replace("event: ", '{"event":') + .replace( + StreamEventTypes.MESSAGE.value, + '"' + f"{StreamEventTypes.MESSAGE.value}" + '"}', + ) + .replace("data: ", "") + .replace("\r\n", "") + .replace("}{", "}, {") + ) + ) + assert len(stream_responses) == 20 + assert all( + resp.get("greeting") == f"Hello world{i} stream" + for resp in stream_responses + if "greeting" in resp + ) + assert all( + resp.get("event") == StreamEventTypes.MESSAGE.value + for resp in stream_responses + if "event" in resp ) - assert response.status_code == 404 -def test_inference_sample_task_incorrect_input( +def test_inference_streaming_sample_module_actual_server_throws( sample_task_model_id, runtime_http_server ): - """Test that with an incorrect input, the test doesn't throw but - instead returns None""" - with TestClient(runtime_http_server.app) as client: - json_input = { + """Simple check for testing an exception in unary-stream case + that pings the actual running server""" + + for i in range(10): + input = { "model_id": sample_task_model_id, - "inputs": {"blah": "world"}, + "inputs": {"name": f"world{i}"}, + "parameters": {"err_stream": True}, } - response = client.post( - f"/api/v1/task/sample", - json=json_input, + url = f"http://localhost:{runtime_http_server.port}/api/v1/task/server-streaming-sample" + stream = requests.post(url=url, json=input, verify=False) + assert stream.status_code == 200 + stream_content = stream.content.decode(stream.encoding) + stream_responses = json.loads( + "[{}]".format( + stream_content.replace("event: ", '{"event":') + .replace( + StreamEventTypes.ERROR.value, + '"' + f"{StreamEventTypes.ERROR.value}" + '"}', + ) + .replace("data: ", "") + .replace("\r\n", "") + .replace("}{", "}, {") + ) ) - assert response.status_code == 422, response.content.decode( - response.default_encoding + assert len(stream_responses) == 2 + assert stream_responses[0].get("event") == StreamEventTypes.ERROR.value + assert ( + stream_responses[1].get("details") == "ValueError('raising a ValueError')" ) + assert stream_responses[1].get("code") == 400 + + +def test_inference_malformed_param(client): + """Send a malformed data parameter field to the inference call to induce the correct HTTP error""" + + response = client.post( + "/api/v1/task/sample", + data='{"bad_input": 100,}', # send intentionally bad json + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 422 + + json_response = json.loads(response.content.decode(response.default_encoding)) + + assert "Invalid JSON" in json_response["details"] + assert json_response["additional_info"][0]["type"] == "json_invalid" + + +def test_inference_non_serializable_json(client): + """Send non_serializable json as the data parameter field to the inference call to test correct error handling""" + + byte_data = bytes([1, 2, 3, 4, 5]) + base64_data = base64.b64encode(byte_data) + + response = client.post( + "/api/v1/task/sample", + 
data=base64_data, # send byte object + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 422 + + json_response = json.loads(response.content.decode(response.default_encoding)) + + assert "Invalid JSON" in json_response["details"] + assert json_response["additional_info"][0]["type"] == "json_invalid" + + +def test_no_model_id(client): + """Simple check to make sure we return a 400 if no model_id in payload""" + response = client.post( + f"/api/v1/task/sample", + json={"inputs": {"name": "world"}}, + ) + assert response.status_code == 400 + assert "Please provide model_id in payload" in response.content.decode( + response.default_encoding + ) + + +def test_inference_multi_task_module(multi_task_model_id, client): + """Simple check that we can infer against the second task of a multi-task model""" + # cGRmZGF0Yf//AA== is b"pdfdata\xff\xff\x00" base64 encoded + json_input = { + "model_id": multi_task_model_id, + "inputs": {"filename": "example.pdf", "data": "cGRmZGF0Yf//AA=="}, + } + response = client.post( + f"/api/v1/task/second", + json=json_input, + ) + json_response = json.loads(response.content.decode(response.default_encoding)) + assert response.status_code == 200, json_response + assert json_response["farewell"] == "Goodbye from SecondTask" + + +def test_model_not_found(client): + """Simple error check to make sure we return a 404 in case of + incorrect model_id""" + response = client.post( + f"/api/v1/task/sample", + json={"model_id": "not_an_id", "inputs": {"name": "world"}}, + ) + assert response.status_code == 404 + + +def test_model_not_found_with_lazy_load_multi_model_finder(open_port): + """An error check to make sure we return a 404 in case of + incorrect model_id while using multi model finder with lazy load enabled""" + with tempfile.TemporaryDirectory() as workdir: + # NOTE: This test requires that the ModelManager class not be a singleton. + # To accomplish this, the singleton instance is temporarily removed.
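+        # For context (a sketch inferred from the config below, not a guarantee + # of the finder API): MultiModelFinder delegates each lookup to the + # finders listed in "finder_priority" (here just the "local" finder), so + # an unknown model_id falls through every finder and surfaces as a 404.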
+ with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": workdir, + "lazy_load_local_models": True, + }, + "model_management": { + "finders": { + "default": { + "type": MultiModelFinder.name, + "config": { + "finder_priority": ["local"], + }, + }, + "local": {"type": "LOCAL"}, + } + }, + }, + "merge", + ): + with runtime_http_test_server(open_port) as server: + # double checking that our local model_management change took effect + assert ( + server.global_predict_servicer._model_manager._lazy_load_local_models + ) + response = requests.post( + f"http://localhost:{server.port}/api/v1/task/sample", + json={"model_id": "not_an_id", "inputs": {"name": "world"}}, + ) + assert response.status_code == 404 + + +def test_inference_sample_task_incorrect_input(sample_task_model_id, client): + """Test that with an incorrect input, we get back a 422""" + json_input = { + "model_id": sample_task_model_id, + "inputs": {"blah": "world"}, + } + response = client.post( + f"/api/v1/task/sample", + json=json_input, + ) + assert response.status_code == 422 + json_response = json.loads(response.content.decode(response.default_encoding)) + # assert standard fields in the response + assert json_response["details"] is not None + assert json_response["code"] is not None + assert json_response["id"] is not None + assert json_response["details"] == "Extra inputs are not permitted" @pytest.mark.skip("Skipping since we're not tackling forward compatibility atm") -def test_inference_sample_task_forward_compatibility( - sample_task_model_id, runtime_http_server -): +def test_inference_sample_task_forward_compatibility(sample_task_model_id, client): """Test that clients can send in params that don't exist on server without any error""" - with TestClient(runtime_http_server.app) as client: - json_input = { - "model_id": sample_task_model_id, - "inputs": {"name": "world", "blah": "blah"}, - } - response = client.post( - f"/api/v1/task/sample", - json=json_input, - ) - json_response = json.loads(response.content.decode(response.default_encoding)) - assert response.status_code == 200, json_response - assert json_response["greeting"] == "Hello world" + json_input = { + "model_id": sample_task_model_id, + "inputs": {"name": "world", "blah": "blah"}, + } + response = client.post( + f"/api/v1/task/sample", + json=json_input, + ) + json_response = json.loads(response.content.decode(response.default_encoding)) + assert response.status_code == 200, json_response + assert json_response["greeting"] == "Hello world" + + +def test_health_check_ok(client): + """Make sure the health check returns OK""" + response = client.get(http_server.HEALTH_ENDPOINT) + assert response.status_code == 200 + assert response.text == "OK" -def test_health_check_ok(runtime_http_server): - """Make sure the health check returns OK""" +def test_runtime_info_ok(runtime_http_server): + """Make sure the runtime info returns version data""" with TestClient(runtime_http_server.app) as client: - response = client.get(http_server.HEALTH_ENDPOINT) + response = client.get(http_server.RUNTIME_INFO_ENDPOINT) assert response.status_code == 200 - assert response.text == "OK" + + json_response = json.loads(response.content.decode(response.default_encoding)) + assert "caikit" in json_response["python_packages"] + # runtime_version not added if not set + assert json_response["runtime_version"] == "" + # dependent libraries not added if the "all" packages flag is not set to true + assert "py_to_proto" not in json_response["python_packages"] + + +def 
test_runtime_info_ok_response_all_packages(runtime_http_server): + """Make sure the runtime info reports versions for all python packages when "all" is set""" + with temp_config( + { + "runtime": { + "version_info": { + "python_packages": { + "all": True, + }, + "runtime_image": "1.2.3", + } + }, + }, + "merge", + ): + with TestClient(runtime_http_server.app) as client: + response = client.get(http_server.RUNTIME_INFO_ENDPOINT) + assert response.status_code == 200 + + json_response = json.loads( + response.content.decode(response.default_encoding) + ) + assert json_response["runtime_version"] == "1.2.3" + assert "caikit" in json_response["python_packages"] + # dependent library versions are added + assert "alog" in json_response["python_packages"] + assert "py_to_proto" in json_response["python_packages"] + + +def test_runtime_info_ok_custom_python_packages(runtime_http_server): + """Make sure custom python packages from config are included in the runtime info""" + with temp_config( + {"runtime": {"version_info": {"python_packages": {"custom_package": "0.1.0"}}}}, + merge_strategy="merge", + ): + with TestClient(runtime_http_server.app) as client: + response = client.get(http_server.RUNTIME_INFO_ENDPOINT) + assert response.status_code == 200 + + json_response = json.loads( + response.content.decode(response.default_encoding) + ) + # runtime_version not added if not set + assert json_response["runtime_version"] == "" + # custom library is set while other random packages are not + assert "caikit" in json_response["python_packages"] + assert json_response["python_packages"]["custom_package"] == "0.1.0" + assert "py_to_proto" not in json_response["python_packages"] + + +def test_all_models_info_ok(client, sample_task_model_id): + """Make sure the models info endpoint returns info for all loaded models""" + response = client.get(http_server.MODELS_INFO_ENDPOINT) + assert response.status_code == 200 + + json_response = json.loads(response.content.decode(response.default_encoding)) + # Assert some models are loaded + assert len(json_response["models"]) > 0 + + found_sample_task = False + for model in json_response["models"]: + # Assert name and id exist + assert model["name"] and model["module_id"] + if model["name"] == sample_task_model_id: + assert model["module_metadata"]["name"] == "SampleModule" + found_sample_task = True + + assert found_sample_task, "Unable to find sample_task model in models list" + + +def test_single_models_info_ok(client, sample_task_model_id): + """Make sure the models info endpoint can filter down to a single model by id""" + response = client.get( + http_server.MODELS_INFO_ENDPOINT, params={"model_ids": sample_task_model_id} + ) + assert response.status_code == 200 + + json_response = json.loads(response.content.decode(response.default_encoding)) + # Assert only the one requested model is returned + assert len(json_response["models"]) == 1 + + model = json_response["models"][0] + assert model["name"] == sample_task_model_id + assert model["module_metadata"]["name"] == "SampleModule" def test_http_server_shutdown_with_model_poll(open_port): @@ -581,185 +945,277 @@ def test_http_server_shutdown_with_model_poll(open_port): ## Train Tests ####################################################################### -def test_train_sample_task(runtime_http_server): +def test_train_sample_task(client, runtime_http_server): model_name = "sample_task_train" - with TestClient(runtime_http_server.app) as client: - json_input = { - "model_name": model_name, - "parameters": { - "training_data": {"data_stream": {"data": [{"number": 1}]}}, - "batch_size": 42, - }, - } - training_response = client.post( - f"/api/v1/SampleTaskSampleModuleTrain", - json=json_input, - ) - - # 
assert training response - training_json_response = json.loads( - training_response.content.decode(training_response.default_encoding) - ) - assert training_response.status_code == 200, training_json_response - assert (training_id := training_json_response["training_id"]) - assert training_json_response["model_name"] == model_name - - # assert trained model - result = MODEL_MANAGER.get_model_future(training_id).load() - assert result.batch_size == 42 - assert ( - result.MODULE_CLASS - == "sample_lib.modules.sample_task.sample_implementation.SampleModule" - ) - - # register the newly trained model for inferencing - register_trained_model( - runtime_http_server.global_predict_servicer, - model_name, - training_id, - ) - - # test inferencing on new model - json_input_inference = {"model_id": model_name, "inputs": {"name": "world"}} - response = client.post( - f"/api/v1/task/sample", - json=json_input_inference, - ) - json_response = json.loads(response.content.decode(response.default_encoding)) - assert response.status_code == 200, json_response - assert json_response["greeting"] == "Hello world" - - -def test_train_sample_task_throws_s3_value_error(runtime_http_server): + json_input = { + "model_name": model_name, + "parameters": { + "training_data": {"data_stream": {"data": [{"number": 1}]}}, + "batch_size": 42, + }, + } + training_response = client.post( + f"/api/v1/SampleTaskSampleModuleTrain", + json=json_input, + ) + + # assert training response + training_json_response = json.loads( + training_response.content.decode(training_response.default_encoding) + ) + assert training_response.status_code == 200, training_json_response + assert (training_id := training_json_response["training_id"]) + assert training_json_response["model_name"] == model_name + + # assert trained model + result = MODEL_MANAGER.get_model_future(training_id).load() + assert result.batch_size == 42 + assert ( + result.MODULE_CLASS + == "sample_lib.modules.sample_task.sample_implementation.SampleModule" + ) + + # register the newly trained model for inferencing + register_trained_model( + runtime_http_server.global_predict_servicer, + model_name, + training_id, + ) + + # test inferencing on new model + json_input_inference = {"model_id": model_name, "inputs": {"name": "world"}} + response = client.post( + f"/api/v1/task/sample", + json=json_input_inference, + ) + json_response = json.loads(response.content.decode(response.default_encoding)) + assert response.status_code == 200, json_response + assert json_response["greeting"] == "Hello world" + + +def test_train_sample_task_throws_s3_value_error(client): """test that if we provide s3 path, it throws an error""" model_name = "sample_task_train" - with TestClient(runtime_http_server.app) as client: - json_input = { - "model_name": model_name, - "output_path": {"path": "non-existent path_to_s3"}, - "parameters": { - "training_data": {"data_stream": {"data": [{"number": 1}]}}, - "batch_size": 42, - }, - } - training_response = client.post( - f"/api/v1/SampleTaskSampleModuleTrain", - json=json_input, - ) - assert ( - "S3 output path not supported by this runtime" - in training_response.content.decode(training_response.default_encoding) - ) - assert training_response.status_code == 500, training_response.content.decode( - training_response.default_encoding - ) - - -def test_train_primitive_task(runtime_http_server): + json_input = { + "model_name": model_name, + "output_path": {"path": "non-existent path_to_s3"}, + "parameters": { + "training_data": {"data_stream": {"data": 
[{"number": 1}]}}, + "batch_size": 42, + }, + } + training_response = client.post( + f"/api/v1/SampleTaskSampleModuleTrain", + json=json_input, + ) + assert ( + "S3 output path not supported by this runtime" + in training_response.content.decode(training_response.default_encoding) + ) + assert training_response.status_code == 500, training_response.content.decode( + training_response.default_encoding + ) + + +def test_train_primitive_task(client, runtime_http_server): model_name = "primitive_task_train" - with TestClient(runtime_http_server.app) as client: - json_input = { - "model_name": model_name, - "parameters": { - "sample_input": {"name": "test"}, - "simple_list": ["hello", "world"], - "union_list": ["hello", "world"], - "union_list2": ["hello", "world"], - "union_list3": ["hello", "world"], - "union_list4": 1, - "training_params_json_dict_list": [{"foo": {"bar": [1, 2, 3]}}], - "training_params_json_dict": {"foo": {"bar": [1, 2, 3]}}, - "training_params_dict": {"layer_sizes": 100, "window_scaling": 200}, - "training_params_dict_int": {1: 0.1, 2: 0.01}, - }, - } - - training_response = client.post( - f"/api/v1/SampleTaskSamplePrimitiveModuleTrain", - json=json_input, - ) - # assert training response - training_json_response = json.loads( - training_response.content.decode(training_response.default_encoding) - ) - assert training_response.status_code == 200, training_json_response - assert (training_id := training_json_response["training_id"]) - assert training_json_response["model_name"] == model_name - - # assert trained model - result = MODEL_MANAGER.get_model_future(training_id).load() - assert result.training_params_dict == { - "layer_sizes": 100, - "window_scaling": 200, - } - assert result.training_params_json_dict == {"foo": {"bar": [1, 2, 3]}} - assert ( - result.MODULE_CLASS - == "sample_lib.modules.sample_task.primitive_party_implementation.SamplePrimitiveModule" - ) - - # register the newly trained model for inferencing - register_trained_model( - runtime_http_server.global_predict_servicer, - model_name, - training_id, - ) - - # test inferencing on new model - json_input_inference = {"model_id": model_name, "inputs": {"name": "world"}} - response = client.post( - f"/api/v1/task/sample", - json=json_input_inference, - ) - json_response = json.loads(response.content.decode(response.default_encoding)) - assert response.status_code == 200, json_response - assert json_response["greeting"] == "hello: primitives! 
[1, 2, 3] 100" + json_input = { + "model_name": model_name, + "parameters": { + "sample_input": {"name": "test"}, + "simple_list": ["hello", "world"], + "union_list": ["hello", "world"], + "union_list2": ["hello", "world"], + "union_list3": ["hello", "world"], + "union_list4": 1, + "training_params_json_dict_list": [{"foo": {"bar": [1, 2, 3]}}], + "training_params_json_dict": {"foo": {"bar": [1, 2, 3]}}, + "training_params_dict": {"layer_sizes": 100, "window_scaling": 200}, + "training_params_dict_int": {1: 0.1, 2: 0.01}, + }, + } + + training_response = client.post( + f"/api/v1/SampleTaskSamplePrimitiveModuleTrain", + json=json_input, + ) + # assert training response + training_json_response = json.loads( + training_response.content.decode(training_response.default_encoding) + ) + assert training_response.status_code == 200, training_json_response + assert (training_id := training_json_response["training_id"]) + assert training_json_response["model_name"] == model_name + + # assert trained model + result = MODEL_MANAGER.get_model_future(training_id).load() + assert result.training_params_dict == { + "layer_sizes": 100, + "window_scaling": 200, + } + assert result.training_params_json_dict == {"foo": {"bar": [1, 2, 3]}} + assert ( + result.MODULE_CLASS + == "sample_lib.modules.sample_task.primitive_party_implementation.SamplePrimitiveModule" + ) + + # register the newly trained model for inferencing + register_trained_model( + runtime_http_server.global_predict_servicer, + model_name, + training_id, + ) + + # test inferencing on new model + json_input_inference = {"model_id": model_name, "inputs": {"name": "world"}} + response = client.post( + f"/api/v1/task/sample", + json=json_input_inference, + ) + json_response = json.loads(response.content.decode(response.default_encoding)) + assert response.status_code == 200, json_response + assert json_response["greeting"] == "hello: primitives! 
[1, 2, 3] 100" + + +def test_train_other_task(client, runtime_http_server): + model_name = "other_task_train" + json_input = { + "model_name": model_name, + "parameters": { + "training_data": {"data_stream": {"data": [1, 2]}}, + "sample_input": {"name": "test"}, + }, + } + + training_response = client.post( + f"/api/v1/OtherTaskOtherModuleTrain", + json=json_input, + ) + # assert training response + training_json_response = json.loads( + training_response.content.decode(training_response.default_encoding) + ) + assert training_response.status_code == 200, training_json_response + assert (training_id := training_json_response["training_id"]) + assert training_json_response["model_name"] == model_name + + # assert trained model + result = MODEL_MANAGER.get_model_future(training_id).load() + assert result.batch_size == 64 + assert ( + result.MODULE_CLASS + == "sample_lib.modules.other_task.other_implementation.OtherModule" + ) + + # register the newly trained model for inferencing + register_trained_model( + runtime_http_server.global_predict_servicer, + model_name, + training_id, + ) + + # test inferencing on new model + json_input_inference = {"model_id": model_name, "inputs": {"name": "world"}} + response = client.post( + f"/api/v1/task/other", + json=json_input_inference, + ) + json_response = json.loads(response.content.decode(response.default_encoding)) + assert response.status_code == 200, json_response + assert json_response["farewell"] == "goodbye: world 64 times" + + +def test_http_and_grpc_server_share_threadpool( + runtime_http_server, runtime_grpc_server +): + assert runtime_grpc_server.thread_pool is runtime_http_server.thread_pool -def test_train_other_task(runtime_http_server): - model_name = "other_task_train" - with TestClient(runtime_http_server.app) as client: - json_input = { - "model_name": model_name, - "parameters": { - "training_data": {"data_stream": {"data": [1, 2]}}, - "sample_input": {"name": "test"}, - }, - } +def test_train_long_running_sample_task(client, runtime_http_server): + """Test that with a long running training job, the request returns before the training completes""" + model_name = "sample_task_train" + json_input = { + "model_name": model_name, + "parameters": { + "training_data": {"data_stream": {"data": [{"number": 1}]}}, + "batch_size": 42, + "sleep_time": 5, # mimic long train time + }, + } + training_response = client.post( + f"/api/v1/SampleTaskSampleModuleTrain", + json=json_input, + ) + + # assert training response received before training completed + training_json_response = json.loads( + training_response.content.decode(training_response.default_encoding) + ) + assert training_response.status_code == 200, training_json_response + assert (training_id := training_json_response["training_id"]) + assert training_json_response["model_name"] == model_name + + # assert that the training is still running + model_future = MODEL_MANAGER.get_model_future(training_id) + assert model_future.get_info().status == TrainingStatus.RUNNING + + # Cancel the training + model_future.cancel() + assert model_future.get_info().status == TrainingStatus.CANCELED + assert model_future.get_info().status.is_terminal + + +def test_uvicorn_server_config_valid(): + """Make sure that arbitrary uvicorn configs can be passed through from + runtime.http.server_config + """ + timeout_keep_alive = 10 + with temp_config( + { + "runtime": { + "http": {"server_config": {"timeout_keep_alive": timeout_keep_alive}} + } + }, + "merge", + ): + server = http_server.RuntimeHTTPServer() + 
assert server.server.config.timeout_keep_alive == timeout_keep_alive - training_response = client.post( - f"/api/v1/OtherTaskOtherModuleTrain", - json=json_input, - ) - # assert training response - training_json_response = json.loads( - training_response.content.decode(training_response.default_encoding) - ) - assert training_response.status_code == 200, training_json_response - assert (training_id := training_json_response["training_id"]) - assert training_json_response["model_name"] == model_name - # assert trained model - result = MODEL_MANAGER.get_model_future(training_id).load() - assert result.batch_size == 64 - assert ( - result.MODULE_CLASS - == "sample_lib.modules.other_task.other_implementation.OtherModule" - ) +def test_uvicorn_server_config_invalid_tls_overlap(): + """Make sure uvicorn TLS arguments cannot be set if TLS is enabled in caikit + config + """ + with temp_config( + { + "runtime": { + "http": { + "server_config": { + "ssl_keyfile": "/some/file.pem", + } + } + } + }, + "merge", + ): + with generate_tls_configs(port=1234, tls=True, mtls=True): + with pytest.raises(ValueError): + http_server.RuntimeHTTPServer() - # register the newly trained model for inferencing - register_trained_model( - runtime_http_server.global_predict_servicer, - model_name, - training_id, - ) - # test inferencing on new model - json_input_inference = {"model_id": model_name, "inputs": {"name": "world"}} - response = client.post( - f"/api/v1/task/other", - json=json_input_inference, - ) - json_response = json.loads(response.content.decode(response.default_encoding)) - assert response.status_code == 200, json_response - assert json_response["farewell"] == "goodbye: world 64 times" +def test_uvicorn_server_config_invalid_kwarg_overlap(): + """Make sure uvicorn config can't be set for configs that caikit manages""" + with temp_config( + { + "runtime": { + "http": { + "server_config": { + "log_level": "debug", + } + } + } + }, + "merge", + ): + with pytest.raises(ValueError): + http_server.RuntimeHTTPServer() diff --git a/tests/runtime/http_server/test_request_aborter.py b/tests/runtime/http_server/test_request_aborter.py new file mode 100644 index 000000000..0f5a451f1 --- /dev/null +++ b/tests/runtime/http_server/test_request_aborter.py @@ -0,0 +1,110 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
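+"""Tests for HttpRequestAborter. + +Rough usage sketch, based only on the calls exercised in this test (not a +definitive API reference): + + aborter = HttpRequestAborter(request, poll_time=0.001) + aborter.set_context(abortable_context) + if aborter.must_abort(): + ... # client disconnected; stop work, the attached context was aborted +"""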
+# Standard +import asyncio +import datetime +import threading +import time + +# Third Party +from fastapi import FastAPI, Request +from requests.exceptions import ReadTimeout +import pytest +import requests +import uvicorn + +# Local +from caikit.runtime.http_server.request_aborter import HttpRequestAborter +from tests.runtime.work_management.test_call_aborter import StubAbortableContext + + +def get_time_remaining(start_time: datetime.datetime, timeout: int = 10) -> float: + now = datetime.datetime.now() + return ((start_time + datetime.timedelta(seconds=timeout)) - now).total_seconds() + + +def test_request_aborter(open_port): + # Get start time to ensure test doesn't hang + start_time = datetime.datetime.now() + + # Create FastAPI app for testing + app = FastAPI() + + # Initialize synchronization variables for tracking request process + abort_context = StubAbortableContext() + request_finished = threading.Event() + + # Define an endpoint that sleeps until the client disconnects. + @app.get("/test_abort") + async def test_aborter(context: Request): + # Create aborter and add parent event + TEST_ABORTER = HttpRequestAborter(context, poll_time=0.001) + TEST_ABORTER.set_context(abort_context) + + # Assign TEST_ABORTER to the parent function. This allows the test to have + # access to this object without using globals + test_request_aborter.TEST_ABORTER = TEST_ABORTER + + # Wait for client to disconnect + while not TEST_ABORTER.must_abort() and get_time_remaining(start_time) > 0: + await asyncio.sleep(0.001) + + request_finished.set() + + # Start up a local uvicorn server in a thread + config = uvicorn.Config( + app, + host="0.0.0.0", + port=open_port, + log_level="trace", + log_config=None, + timeout_graceful_shutdown=None, + ) + server = uvicorn.Server(config=config) + server_thread = threading.Thread(target=server.run) + server_thread.start() + + server_exception = None + try: + # Wait for uvicorn to start + while not server.started: + if get_time_remaining(start_time) < 0: + raise TimeoutError("Server did not start in time") + + time.sleep(0.001) + + # Try the endpoint but timeout after 10ms + with pytest.raises(ReadTimeout): + requests.get( + f"http://localhost:{open_port}/test_abort", + timeout=0.01, + ) + + # Wait for the request to finish/abort + request_finished.wait(get_time_remaining(start_time)) + + # Assert the request aborter actually aborted + assert test_request_aborter.TEST_ABORTER.must_abort() + assert abort_context.aborted + assert request_finished.is_set() + + except Exception as exc: + server_exception = exc + finally: + # Clean up the server + server.should_exit = True + server_thread.join() + + if server_exception: + raise server_exception diff --git a/tests/runtime/model_management/test_model_loader.py b/tests/runtime/model_management/test_model_loader.py index d17da8af0..592d6b2f7 100644 --- a/tests/runtime/model_management/test_model_loader.py +++ b/tests/runtime/model_management/test_model_loader.py @@ -114,7 +114,7 @@ def test_load_invalid_model_error_response(model_loader): local_model_path=Fixtures.get_bad_model_archive_path(), model_type="not_real", ).wait() - assert context.value.status_code == grpc.StatusCode.INTERNAL + assert context.value.status_code == grpc.StatusCode.NOT_FOUND assert model_id in context.value.message @@ -308,7 +308,7 @@ def test_load_model_without_waiting_deferred_error(model_loader): ) with pytest.raises(CaikitRuntimeException) as context: loaded_model.model() - assert context.value.status_code == grpc.StatusCode.INTERNAL + assert 
context.value.status_code == grpc.StatusCode.NOT_FOUND assert model_id in context.value.message diff --git a/tests/runtime/model_management/test_model_manager.py b/tests/runtime/model_management/test_model_manager.py index 430967256..27825da82 100644 --- a/tests/runtime/model_management/test_model_manager.py +++ b/tests/runtime/model_management/test_model_manager.py @@ -16,6 +16,7 @@ from contextlib import contextmanager from functools import partial from tempfile import TemporaryDirectory +from typing import Optional from unittest.mock import MagicMock, patch import os import shutil @@ -26,14 +27,21 @@ import grpc import pytest +# First Party +from aconfig.aconfig import Config +import aconfig + # Local from caikit import get_config +from caikit.core.model_management import ModelFinderBase +from caikit.core.model_management.local_model_initializer import LocalModelInitializer from caikit.core.model_manager import ModelManager as CoreModelManager -from caikit.core.modules import ModuleBase +from caikit.core.modules import ModuleBase, ModuleConfig from caikit.runtime.model_management.loaded_model import LoadedModel from caikit.runtime.model_management.model_manager import ModelManager from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException from caikit.runtime.utils.import_util import get_dynamic_module +from sample_lib.data_model import SampleInputType from tests.conftest import TempFailWrapper, random_test_id, temp_config from tests.core.helpers import TestFinder from tests.fixtures import Fixtures @@ -110,7 +118,7 @@ def test_load_model_ok_response(): model_id=model_id, local_model_path=Fixtures.get_good_model_path(), model_type=Fixtures.get_good_model_type(), - ) + ).size() assert model_size > 0 @@ -127,7 +135,7 @@ def test_load_model_no_size_update(): model_id=model_id, local_model_path=Fixtures.get_good_model_path(), model_type=Fixtures.get_good_model_type(), - ) + ).size() assert model_size > 0 loaded_model = MODEL_MANAGER.loaded_models[model_id] assert loaded_model.size() == model_size @@ -151,7 +159,8 @@ def test_load_local_models(): assert "model-does-not-exist.zip" not in MODEL_MANAGER.loaded_models.keys() -def test_model_manager_loads_local_models_on_init(): +@pytest.mark.parametrize("wait", [True, False]) +def test_model_manager_loads_local_models_on_init(wait): with TemporaryDirectory() as tempdir: shutil.copytree(Fixtures.get_good_model_path(), os.path.join(tempdir, "model1")) shutil.copy( @@ -160,7 +169,13 @@ def test_model_manager_loads_local_models_on_init(): ) ModelManager._ModelManager__instance = None with temp_config( - {"runtime": {"local_models_dir": tempdir}}, merge_strategy="merge" + { + "runtime": { + "local_models_dir": tempdir, + "wait_for_initial_model_loads": wait, + }, + }, + merge_strategy="merge", ): MODEL_MANAGER = ModelManager() @@ -169,6 +184,11 @@ def test_model_manager_loads_local_models_on_init(): assert "model2.zip" in MODEL_MANAGER.loaded_models.keys() assert "model-does-not-exist.zip" not in MODEL_MANAGER.loaded_models.keys() + # Make sure that the loaded model can be retrieved and run + for model_name in ["model1", "model2.zip"]: + model = MODEL_MANAGER.retrieve_model(model_name) + model.run(SampleInputType("hello")) + def test_load_model_error_response(): """Test load model's model does not exist when the loader throws""" @@ -434,7 +454,7 @@ def test_model_manager_disk_caching_periodic_sync(good_model_path): """Make sure that when using disk caching, the manager periodically syncs its loaded models based on their presence 
in the cache """ - purge_period = 0.001 + purge_period = 0.002 with TemporaryDirectory() as cache_dir: with non_singleton_model_managers( 2, @@ -485,6 +505,142 @@ def test_model_manager_disk_caching_periodic_sync(good_model_path): assert mgr_one_unloaded and mgr_two_unloaded +def test_lazy_load_of_large_model(good_model_path): + """Test that a large model that is actively being written to disk is not incorrectly loaded + too soon by the lazy loading poll + """ + with TemporaryDirectory() as cache_dir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": cache_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": 0, + }, + }, + "merge", + ) as managers: + manager = managers[0] + + # Start with a valid model + model_name = os.path.basename(good_model_path) + model_cache_path = os.path.join(cache_dir, model_name) + assert not os.path.exists(model_cache_path) + shutil.copytree(good_model_path, model_cache_path) + + # Then kick off a thread that will start writing a large file inside this model dir. + # This simulates uploading a large model artifact + def write_big_file(path: str, stop_event: threading.Event): + big_file = os.path.join(path, "big_model_artifact.txt") + with open(big_file, "w") as bf: + while not stop_event.is_set(): + bf.write("This is a big file\n" * 1000) + + stop_write_event = threading.Event() + writer_thread = threading.Thread( + target=write_big_file, args=(model_cache_path, stop_write_event) + ) + writer_thread.start() + + try: + # Trigger the periodic sync and make sure the model is NOT loaded + assert model_name not in manager.loaded_models + manager.sync_local_models(wait=True) + assert model_name not in manager.loaded_models + + # Stop the model writing thread (Finish the model upload) + stop_write_event.set() + writer_thread.join() + + # Re-trigger the sync and make sure the model is loaded this time + manager.sync_local_models(wait=True) + assert model_name in manager.loaded_models + + finally: + stop_write_event.set() + writer_thread.join() + + +def test_nested_local_model_load_unload(good_model_path): + """Test that a model can be loaded in a subdirectory of the local_models_dir + and that the periodic sync does not unload the model. 
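+    Nested model ids (e.g. "parent/<model>") are expected to load only when + explicitly requested via retrieve_model, never via the background sync.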
+ """ + with TemporaryDirectory() as cache_dir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": cache_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": 0, + }, + }, + "merge", + ) as managers: + manager = managers[0] + + # Copy the model into a nested model directory + model_name = os.path.join("parent", os.path.basename(good_model_path)) + model_cache_path = os.path.join(cache_dir, model_name) + assert not os.path.exists(model_cache_path) + shutil.copytree(good_model_path, model_cache_path) + + # Trigger the periodic sync and make sure the model is NOT loaded + assert model_name not in manager.loaded_models + manager.sync_local_models(wait=True) + assert model_name not in manager.loaded_models + + # Explicitly ask to load the nested model name to trigger the lazy + # load + model = manager.retrieve_model(model_name) + assert model + assert model_name in manager.loaded_models + + # Re-trigger the sync and make sure the model does not get unloaded + manager.sync_local_models(wait=True) + assert model_name in manager.loaded_models + + +def test_model_unload_race(good_model_path): + """Test that if a model gets unloaded _while_ it's actively being loaded + (before retrieve_model completes, but after load_model completes), no + exception is raised. + """ + with TemporaryDirectory() as cache_dir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": cache_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": 0, + }, + }, + "merge", + ) as managers: + manager = managers[0] + + # Copy the model to the local_models_dir + model_id = random_test_id() + model_cache_path = os.path.join(cache_dir, model_id) + shutil.copytree(good_model_path, model_cache_path) + + # Patch the manager's load_model to immediately unload the model + orig_load_model = manager.load_model + + def load_and_unload_model(self, model_id: str, *args, **kwargs): + res = orig_load_model(model_id, *args, **kwargs) + manager.unload_model(model_id) + return res + + with patch.object(manager.__class__, "load_model", load_and_unload_model): + + # Retrieve the model and make sure there's no error + assert manager.retrieve_model(model_id) + assert model_id not in manager.loaded_models + + def test_load_local_model_deleted_dir(): """Make sure losing the local_models_dir out from under a running manager doesn't kill the whole thing @@ -590,7 +746,7 @@ def test_load_model(): model_size = MODEL_MANAGER.load_model( model_id, ANY_MODEL_PATH, ANY_MODEL_TYPE - ) + ).size() assert expected_model_size == model_size mock_loader.load_model.assert_called_once() call_args = mock_loader.load_model.call_args @@ -599,7 +755,6 @@ def test_load_model(): ANY_MODEL_PATH, ANY_MODEL_TYPE, ) - assert call_args.kwargs["aborter"] is None assert "fail_callback" in call_args.kwargs mock_sizer.get_model_size.assert_called_once_with( model_id, ANY_MODEL_PATH, ANY_MODEL_TYPE @@ -754,12 +909,12 @@ def test_reload_partially_loaded(): mock_loader.load_model.return_value = loaded_model model_size = MODEL_MANAGER.load_model( model_id, ANY_MODEL_PATH, ANY_MODEL_TYPE, wait=False - ) + ).size() assert model_size == special_model_size assert ( MODEL_MANAGER.load_model( model_id, ANY_MODEL_PATH, ANY_MODEL_TYPE, wait=False - ) + ).size() == special_model_size ) @@ -918,3 +1073,177 @@ def test_lazy_load_handles_temporary_errors(): assert manager._lazy_sync_timer is None model = manager.retrieve_model(model_name) assert model + + +def 
test_lazy_load_true_local_models_dir_valid():
+    """When lazy_load_local_models is True and local_models_dir exists,
+    check that local_models_dir points to the correct location.
+    """
+
+    with TemporaryDirectory() as cache_dir:
+
+        ModelManager._ModelManager__instance = None
+        with temp_config(
+            {
+                "runtime": {
+                    "local_models_dir": cache_dir,
+                    "lazy_load_local_models": True,
+                }
+            },
+            merge_strategy="merge",
+        ):
+            MODEL_MANAGER = ModelManager()
+            assert len(MODEL_MANAGER.loaded_models) == 0
+            assert MODEL_MANAGER._local_models_dir == cache_dir
+
+
+def test_lazy_load_true_local_models_dir_invalid():
+    """When lazy_load_local_models is True and local_models_dir does not exist,
+    a ValueError is raised with an appropriate message.
+    """
+
+    with TemporaryDirectory() as cache_dir:
+
+        with pytest.raises(
+            ValueError,
+            match=(
+                "runtime.local_models_dir must be a valid path"
+                " if set with runtime.lazy_load_local_models. "
+                "Provided path: invalid"
+            ),
+        ):
+
+            ModelManager._ModelManager__instance = None
+            with temp_config(
+                {
+                    "runtime": {
+                        "local_models_dir": "invalid",
+                        "lazy_load_local_models": True,
+                    }
+                },
+                merge_strategy="merge",
+            ):
+                MODEL_MANAGER = ModelManager()
+
+
+def test_lazy_load_true_local_models_dir_none():
+    """When lazy_load_local_models is True and local_models_dir is not set in the config,
+    a ValueError is raised with an appropriate message.
+    """
+
+    with TemporaryDirectory() as cache_dir:
+
+        with pytest.raises(
+            ValueError,
+            match=(
+                "runtime.local_models_dir must be set"
+                " if using runtime.lazy_load_local_models. "
+            ),
+        ):
+
+            ModelManager._ModelManager__instance = None
+            with temp_config(
+                {
+                    "runtime": {
+                        "local_models_dir": None,
+                        "lazy_load_local_models": True,
+                    }
+                },
+                merge_strategy="merge",
+            ):
+                MODEL_MANAGER = ModelManager()
+
+
+def test_lazy_load_false_local_models_dir_valid():
+    """When lazy_load_local_models is False and local_models_dir exists,
+    check that local_models_dir points to the correct location.
+    """
+
+    with TemporaryDirectory() as cache_dir:
+
+        ModelManager._ModelManager__instance = None
+        with temp_config(
+            {
+                "runtime": {
+                    "local_models_dir": cache_dir,
+                    "lazy_load_local_models": False,
+                }
+            },
+            merge_strategy="merge",
+        ):
+            MODEL_MANAGER = ModelManager()
+            assert len(MODEL_MANAGER.loaded_models) == 0
+            assert MODEL_MANAGER._local_models_dir == cache_dir
+
+
+def test_lazy_load_false_local_models_dir_invalid():
+    """When lazy_load_local_models is False and local_models_dir does not exist,
+    check that the manager's local_models_dir is falsy / empty.
+    """
+
+    with TemporaryDirectory() as cache_dir:
+
+        ModelManager._ModelManager__instance = None
+        with temp_config(
+            {
+                "runtime": {
+                    "local_models_dir": "",
+                    "lazy_load_local_models": False,
+                }
+            },
+            merge_strategy="merge",
+        ):
+            MODEL_MANAGER = ModelManager()
+            assert len(MODEL_MANAGER.loaded_models) == 0
+            assert not MODEL_MANAGER._local_models_dir
+
+
+class NoModelFinder(ModelFinderBase):
+    name = "NOMODEL"
+
+    def __init__(self, config: Config, instance_name: str):
+        super().__init__(config, instance_name)
+
+    def find_model(self, model_path: str, **kwargs) -> ModuleConfig:
+        raise FileNotFoundError(f"Unable to find model {model_path}")
+
+
+def test_load_model_custom_finder():
+    """Test to ensure loading a model works with a custom finder"""
+    bad_finder = NoModelFinder(aconfig.Config({}), "bad_instance")
+
+    model_id = random_test_id()
+    with pytest.raises(CaikitRuntimeException) as exp:
+        MODEL_MANAGER.load_model(
+            model_id=model_id,
+            local_model_path=Fixtures.get_good_model_path(),
+            model_type=Fixtures.get_good_model_type(),
+            finder=bad_finder,
+        )
+    assert exp.value.status_code == grpc.StatusCode.NOT_FOUND
+
+
+class CustomParamInitializer(LocalModelInitializer):
+    name = "CUSTOMPARAM"
+
+    def init(self, model_config: ModuleConfig, **kwargs) -> ModuleBase:
+        module = super().init(model_config, **kwargs)
+        module.custom_param = True
+        return module
+
+
+def test_load_model_custom_initializer():
+    """Test to ensure loading a model works with a custom initializer"""
+
+    custom_param_initializer = CustomParamInitializer(
+        aconfig.Config({}), "custom_param"
+    )
+    model_id = random_test_id()
+    model = MODEL_MANAGER.load_model(
+        model_id=model_id,
+        local_model_path=Fixtures.get_good_model_path(),
+        model_type=Fixtures.get_good_model_type(),
+        initializer=custom_param_initializer,
+    ).model()
+    assert model
+    assert model.custom_param
diff --git a/tests/runtime/service_generation/test_create_service.py b/tests/runtime/service_generation/test_create_service.py
index b930343d8..a8a7b62cb 100644
--- a/tests/runtime/service_generation/test_create_service.py
+++ b/tests/runtime/service_generation/test_create_service.py
@@ -61,8 +61,9 @@ def run(self, sample_input: SampleInputType) -> SampleOutputType:
     # SampleModule also implements `SampleTask`
     rpcs = create_inference_rpcs([NewModule, SampleModule])
     assert len(rpcs) == 3  # SampleModule has 3 streaming flavors
-    assert NewModule in rpcs[0].module_list
+    assert NewModule in rpcs[1].module_list
     assert SampleModule in rpcs[0].module_list
+    assert SampleModule in rpcs[2].module_list
 
 
 def test_create_inference_rpcs_includes_backend_modules():
@@ -257,12 +258,14 @@ def test_create_inference_rpcs_for_multiple_modules_of_same_type():
     # 4 RPCs, SampleModule and SamplePrimitiveModule have task SampleTask with 3 flavors for
     # streaming, OtherModule has task OtherTask
+    # and the rpcs should be sorted by name (i.e. ['BidiStreamingSampleTaskPredict', 'OtherTaskPredict',
+    # 'SampleTaskPredict', 'ServerStreamingSampleTaskPredict'])
     assert len(rpcs) == 4
-    assert sample_lib.modules.sample_task.SampleModule in rpcs[0].module_list
-    assert sample_lib.modules.sample_task.SamplePrimitiveModule in rpcs[0].module_list
-    assert sample_lib.modules.sample_task.SampleModule in rpcs[1].module_list
     assert sample_lib.modules.sample_task.SampleModule in rpcs[2].module_list
-    assert sample_lib.modules.other_task.OtherModule in rpcs[-1].module_list
+    assert sample_lib.modules.sample_task.SamplePrimitiveModule in rpcs[2].module_list
+    assert sample_lib.modules.sample_task.SampleModule in rpcs[3].module_list
+    assert sample_lib.modules.sample_task.SampleModule in rpcs[0].module_list
+    assert sample_lib.modules.other_task.OtherModule in rpcs[1].module_list
 
 
 def test_create_inference_rpcs_respects_sorted_order_by_module_id():
@@ -275,20 +279,21 @@ def test_create_inference_rpcs_respects_sorted_order_by_module_id():
     # 3 RPCs, SampleModule, SamplePrimitiveModule and ListModule have task SampleTask with 3 flavors for
     # streaming
+    # and the rpcs should be sorted by name (i.e. ['BidiStreamingSampleTaskPredict', 'SampleTaskPredict', 'ServerStreamingSampleTaskPredict'])
     assert len(rpcs) == 3
     assert sample_lib.modules.sample_task.SampleModule in rpcs[0].module_list
-    assert sample_lib.modules.sample_task.SamplePrimitiveModule in rpcs[0].module_list
+    assert sample_lib.modules.sample_task.SamplePrimitiveModule in rpcs[1].module_list
     assert sample_lib.modules.sample_task.SampleModule in rpcs[1].module_list
     assert sample_lib.modules.sample_task.SampleModule in rpcs[2].module_list
-    assert sample_lib.modules.sample_task.ListModule in rpcs[0].module_list
+    assert sample_lib.modules.sample_task.ListModule in rpcs[1].module_list
 
-    # check for alphabetical order of modules in rpcs[0] by Module ID
+    # Within rpc SampleTaskPredict, check for alphabetical order of modules by Module ID
     # this should always be deterministic
-    assert sample_lib.modules.sample_task.SampleModule == rpcs[0].module_list[0]
+    assert sample_lib.modules.sample_task.SampleModule == rpcs[1].module_list[0]
     assert (
-        sample_lib.modules.sample_task.SamplePrimitiveModule == rpcs[0].module_list[1]
+        sample_lib.modules.sample_task.SamplePrimitiveModule == rpcs[1].module_list[1]
     )
-    assert sample_lib.modules.sample_task.ListModule == rpcs[0].module_list[-1]
+    assert sample_lib.modules.sample_task.ListModule == rpcs[1].module_list[-1]
 
 
 def test_create_inference_rpcs_removes_modules_with_no_task():
diff --git a/tests/runtime/service_generation/test_data_stream_source.py b/tests/runtime/service_generation/test_data_stream_source.py
index 3564bc0ca..62c91b018 100644
--- a/tests/runtime/service_generation/test_data_stream_source.py
+++ b/tests/runtime/service_generation/test_data_stream_source.py
@@ -104,6 +104,8 @@ def test_pickle_round_trip_primitive():
     round_trip = pickle.loads(pickle.dumps(inst))
     assert round_trip.to_dict() == inst.to_dict()
 
+    validate_data_stream(round_trip, 2, float)
+
 
 def test_pickle_round_trip_data_model():
     """Make sure that a source wrapping a data model object can be round-tripped
@@ -122,6 +124,17 @@ class Foo(DataObjectBase):
     round_trip = pickle.loads(pickle.dumps(inst))
     assert round_trip.to_dict() == inst.to_dict()
 
+    validate_data_stream(round_trip, 2, Foo)
+
+
+def test_pickle_round_trip_file(sample_json_file):
+    stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType
+    data_stream = stream_type(file=FileReference(filename=sample_json_file))
+
+    round_trip = pickle.loads(pickle.dumps(data_stream))
+
+    validate_data_stream(round_trip, 2, SampleTrainingType)
+
 
 def test_data_stream_source_as_data_stream():
     """Make sure that a DataStreamSource works exactly like a DataStream"""
@@ -544,3 +557,20 @@ def test_s3_not_implemented():
     ) as e:
         for val in ds:
             _ = val
+
+
+def test_datastream_sources_not_repeatedly_read(sample_json_file):
+    """This test ensures that `to_data_stream` is only called once on a single instance of a
+    DataStreamSource.
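+    Once read, the stream stays readable even if the original source field is cleared.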
This allows source plugin authors to control how data is cached for the + life of the stream""" + stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType + data_stream = stream_type(file=FileReference(filename=sample_json_file)) + + # Read the datastream as normal + validate_data_stream(data_stream, 2, SampleTrainingType) + + # Delete the set source field so that .to_data_stream does not know how to find a source + data_stream.file = None + + # Check that we can still read through the stream + validate_data_stream(data_stream, 2, SampleTrainingType) diff --git a/tests/runtime/servicers/test_global_predict_servicer_impl.py b/tests/runtime/servicers/test_global_predict_servicer_impl.py index 70b9ab54e..40843c986 100644 --- a/tests/runtime/servicers/test_global_predict_servicer_impl.py +++ b/tests/runtime/servicers/test_global_predict_servicer_impl.py @@ -16,7 +16,7 @@ from unittest.mock import MagicMock, patch # Local -from caikit.core.data_model import ProducerId +from caikit.core.data_model import DataStream, ProducerId from caikit.runtime.service_factory import get_inference_request from sample_lib.data_model.sample import GeoSpatialTask from sample_lib.modules import MultiTaskModule, SecondTask @@ -134,6 +134,24 @@ def test_global_predict_works_for_unary_rpcs( assert response == HAPPY_PATH_RESPONSE +def test_global_predict_explicit_inference_function( + sample_inference_service, + sample_predict_servicer, + sample_task_model_id, + sample_task_unary_rpc, +): + """Calling predict_model can explicitly select the inference function""" + predict_class = get_inference_request(SampleTask) + request_name = sample_task_unary_rpc.request.name + response = sample_predict_servicer.predict_model( + request_name=request_name, + model_id=sample_task_model_id, + inference_func_name="run_stream_out", + sample_input=HAPPY_PATH_INPUT_DM, + ) + assert isinstance(response, DataStream) + + def test_global_predict_works_on_bidirectional_streaming_rpcs( sample_inference_service, sample_predict_servicer, sample_task_model_id ): @@ -271,7 +289,7 @@ def run(self, *args, **kwargs): assert dummy_model.started.wait(2) # Simulate a timeout or client abort context.cancel() - predict_thread.join(10) + predict_thread.join(2) # Make sure the prediction actually stopped assert not predict_thread.is_alive() diff --git a/tests/runtime/servicers/test_model_runtime_servicer_impl.py b/tests/runtime/servicers/test_model_runtime_servicer_impl.py index d7f7b8d51..20be79e16 100644 --- a/tests/runtime/servicers/test_model_runtime_servicer_impl.py +++ b/tests/runtime/servicers/test_model_runtime_servicer_impl.py @@ -33,79 +33,72 @@ from tests.fixtures import Fixtures -class TestModelRuntimeServicerImpl(unittest.TestCase): - """This test suite tests the ModelRuntimeServicerImpl class""" - - def setUp(self): - """This method runs before each test begins to run""" - self.servicer = ModelRuntimeServicerImpl() - - def test_model_load_sets_per_model_concurrency(self): - model = "test-any-model-id" - # Grab a model type that has some max concurrency set - model_type = list( - get_config().inference_plugin.model_mesh.max_model_concurrency_per_type.keys() - )[0] - request = model_runtime_pb2.LoadModelRequest( - modelId=model, modelType=model_type - ) - context = Fixtures.build_context(model) - - expected_concurrency = ( - get_config().inference_plugin.model_mesh.max_model_concurrency_per_type[ - model_type - ] - ) - mock_manager = MagicMock() - mock_manager.load_model.return_value = 1 - - with 
patch.object(self.servicer, "model_manager", mock_manager):
-            response = self.servicer.loadModel(request, context)
-            self.assertEqual(expected_concurrency, response.maxConcurrency)
-
-    def test_model_load_sets_default_max_model_concurrency(self):
-        model = "test-any-model-id"
-        model_type = "some-fake-model-type"
-        request = model_runtime_pb2.LoadModelRequest(
-            modelId=model, modelType=model_type
-        )
-        context = Fixtures.build_context(model)
-
-        expected_concurrency = (
-            get_config().inference_plugin.model_mesh.max_model_concurrency
-        )
-        mock_manager = MagicMock()
-        mock_manager.load_model.return_value = 1
-
-        with patch.object(self.servicer, "model_manager", mock_manager):
-            response = self.servicer.loadModel(request, context)
-            self.assertEqual(expected_concurrency, response.maxConcurrency)
-
-    def test_load_model_aborts(self):
-        """ModelRuntimeServicer.loadModel will abort a long-running load"""
-        model = "test-any-model-id"
-        request = model_runtime_pb2.LoadModelRequest(modelId=model)
-        context = Fixtures.build_context(model)
-
-        mock_manager = MagicMock()
-        started = Event()
-
-        def never_return(*args, **kwargs):
-            started.set()
-            while True:
-                time.sleep(0.01)
-
-        mock_manager.load_model.side_effect = never_return
-        load_thread = Thread(target=self.servicer.loadModel, args=(request, context))
-
-        with catch_threading_exception() as cm:
-            with patch.object(self.servicer, "model_manager", mock_manager):
-                load_thread.start()
-                started.wait()
-                context.cancel()
-                load_thread.join(10)
-
-        self.assertFalse(load_thread.is_alive())
-
-        # Make sure the correct exception was raised
-        assert cm.exc_type == AbortedException
+def test_model_load_sets_per_model_concurrency(model_runtime_servicer):
+    model = "test-any-model-id"
+    # Grab a model type that has some max concurrency set
+    model_type = list(
+        get_config().inference_plugin.model_mesh.max_model_concurrency_per_type.keys()
+    )[0]
+    request = model_runtime_pb2.LoadModelRequest(modelId=model, modelType=model_type)
+    context = Fixtures.build_context(model)
+
+    expected_concurrency = (
+        get_config().inference_plugin.model_mesh.max_model_concurrency_per_type[
+            model_type
+        ]
+    )
+    mock_manager = MagicMock()
+    mock_manager.load_model.return_value.size.return_value = 1
+
+    with patch.object(model_runtime_servicer, "model_manager", mock_manager):
+        response = model_runtime_servicer.loadModel(request, context)
+        assert expected_concurrency == response.maxConcurrency
+
+
+def test_model_load_sets_default_max_model_concurrency(model_runtime_servicer):
+    model = "test-any-model-id"
+    model_type = "some-fake-model-type"
+    request = model_runtime_pb2.LoadModelRequest(modelId=model, modelType=model_type)
+    context = Fixtures.build_context(model)
+
+    expected_concurrency = (
+        get_config().inference_plugin.model_mesh.max_model_concurrency
+    )
+    mock_manager = MagicMock()
+    mock_manager.load_model.return_value.size.return_value = 1
+
+    with patch.object(model_runtime_servicer, "model_manager", mock_manager):
+        response = model_runtime_servicer.loadModel(request, context)
+        assert expected_concurrency == response.maxConcurrency
+
+
+def test_load_model_aborts(model_runtime_servicer):
+    """ModelRuntimeServicer.loadModel will abort a long-running load"""
+    model = "test-any-model-id"
+    request = model_runtime_pb2.LoadModelRequest(modelId=model)
+    context = Fixtures.build_context(model)
+
+    mock_manager = MagicMock()
+    started = Event()
+
+    def never_return(*args, **kwargs):
+        started.set()
+        while True:
+            time.sleep(0.01)
+
+    mock_manager.load_model.side_effect = never_return
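+    # Run loadModel in a thread so the test can cancel the gRPC context mid-load
+    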
load_thread = Thread( + target=model_runtime_servicer.loadModel, args=(request, context) + ) + + with catch_threading_exception() as cm: + with patch.object(model_runtime_servicer, "model_manager", mock_manager): + load_thread.start() + started.wait() + context.cancel() + load_thread.join(10) + + assert not load_thread.is_alive() + + # Make sure the correct exception was raised + assert cm.exc_type == AbortedException diff --git a/tests/runtime/servicers/test_training_management_servicer.py b/tests/runtime/servicers/test_training_management_servicer.py index 7f36510b9..54e745e94 100644 --- a/tests/runtime/servicers/test_training_management_servicer.py +++ b/tests/runtime/servicers/test_training_management_servicer.py @@ -206,6 +206,7 @@ def test_training_cancel_on_correct_id(training_management_servicer): ) # training number 2 should still complete + model_future_2.wait() request_2 = TrainingInfoRequest(training_id=model_future_2.id).to_proto() response_2 = training_management_servicer.GetTrainingStatus(request_2, context=None) assert response_2.state == TrainingStatus.COMPLETED.value diff --git a/tests/runtime/test_caikit_health_probe.py b/tests/runtime/test_caikit_health_probe.py new file mode 100644 index 000000000..c42cca7d1 --- /dev/null +++ b/tests/runtime/test_caikit_health_probe.py @@ -0,0 +1,242 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the uniform health probe + +🌶️🌶️🌶️ This test relies on test infrastructure in caikit.runtime, so the test +needs to live inside tests/runtime even though the functionality being tested is +not. If this is moved to the top of tests, the runtime test infra boots up too +early causing some of the core tests to fail! 
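+
+The readiness probe is exercised across TLS modes and gRPC/HTTP server combinations;
+the liveness probe checks for a running caikit.runtime process.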
+""" +# Standard +from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum +from unittest import mock +import os +import shlex +import subprocess +import sys + +# Third Party +import pytest +import tls_test_tools + +# First Party +from caikit_health_probe import __main__ as caikit_health_probe +import alog + +# Local +from caikit import get_config +from tests.conftest import temp_config +from tests.runtime.conftest import runtime_grpc_test_server, runtime_http_test_server +from tests.runtime.http_server.test_http_server import generate_tls_configs + +## Helpers ##################################################################### + +log = alog.use_channel("TEST") + + +@contextmanager +def maybe_runtime_grpc_test_server(*args, **kwargs): + if get_config().runtime.grpc.enabled: + with runtime_grpc_test_server(*args, **kwargs) as grpc_server: + yield grpc_server + else: + yield + + +@contextmanager +def maybe_runtime_http_test_server(*args, **kwargs): + if get_config().runtime.http.enabled: + with runtime_http_test_server(*args, **kwargs) as http_server: + yield http_server + else: + yield + + +@contextmanager +def temp_probe_config(*args, **kwargs): + with temp_config(*args, **kwargs) as the_config: + get_config_mock = mock.MagicMock(return_value=the_config) + with mock.patch.object(caikit_health_probe, "get_config", get_config_mock): + yield + + +class TlsMode(Enum): + INSECURE = 0 + TLS = 1 + MTLS = 3 + + +class ServerMode(Enum): + HTTP = 0 + GRPC = 1 + BOTH = 2 + + +@dataclass +class ProbeTestConfig: + tls_mode: TlsMode + server_mode: ServerMode + # TLS blobs passed as inline strings instead of files + inline: bool = False + # Run the unix socket grpc server + unix_socket: bool = True + # Put "localhost" in the SAN list for the server's cert + localhost_in_cert: bool = True + # Use a common CA for client and server certs (mTLS only) + common_client_ca: bool = True + # Whether the test should eventually become healthy + should_become_healthy: bool = True + + +## Tests ####################################################################### + + +@pytest.mark.parametrize( + "test_config", + [ + # Insecure + ProbeTestConfig(TlsMode.INSECURE, ServerMode.HTTP), + ProbeTestConfig(TlsMode.INSECURE, ServerMode.GRPC), + ProbeTestConfig(TlsMode.INSECURE, ServerMode.BOTH), + # TLS + ProbeTestConfig(TlsMode.TLS, ServerMode.HTTP), + ProbeTestConfig(TlsMode.TLS, ServerMode.GRPC), + ProbeTestConfig(TlsMode.TLS, ServerMode.BOTH), + ProbeTestConfig(TlsMode.TLS, ServerMode.BOTH, inline=True), + ProbeTestConfig(TlsMode.TLS, ServerMode.BOTH, localhost_in_cert=False), + # mTLS + ProbeTestConfig(TlsMode.MTLS, ServerMode.HTTP), + ProbeTestConfig(TlsMode.MTLS, ServerMode.GRPC), + ProbeTestConfig(TlsMode.MTLS, ServerMode.BOTH), + ProbeTestConfig(TlsMode.MTLS, ServerMode.BOTH, inline=True), + ProbeTestConfig(TlsMode.MTLS, ServerMode.BOTH, localhost_in_cert=False), + ProbeTestConfig(TlsMode.MTLS, ServerMode.BOTH, common_client_ca=False), + # Invalid configs that never pass + ProbeTestConfig( + TlsMode.TLS, + ServerMode.GRPC, + localhost_in_cert=False, + unix_socket=False, + should_become_healthy=False, + ), + ProbeTestConfig( + TlsMode.TLS, + ServerMode.BOTH, + localhost_in_cert=False, + unix_socket=False, + should_become_healthy=False, + ), + ProbeTestConfig( + TlsMode.MTLS, + ServerMode.GRPC, + localhost_in_cert=False, + unix_socket=False, + should_become_healthy=False, + ), + ], +) +def test_readiness_probe(test_config: ProbeTestConfig): + """Test all of the different ways 
that the servers could be running""" + with alog.ContextLog(log.info, "---LOG CONFIG: %s---", test_config): + # Get ports for both servers + http_port = tls_test_tools.open_port() + grpc_port = tls_test_tools.open_port() + + # Set up SAN lists if not putting "localhost" in + server_sans, client_sans = None, None + if not test_config.localhost_in_cert: + server_sans = ["foo.bar"] + client_sans = ["baz.bat"] + + # Set up tls values if needed + with generate_tls_configs( + port=http_port, + tls=test_config.tls_mode == TlsMode.TLS, + mtls=test_config.tls_mode == TlsMode.MTLS, + inline=test_config.inline, + separate_client_ca=not test_config.common_client_ca, + server_sans=server_sans, + client_sans=client_sans, + ) as config_overrides: + with temp_probe_config( + { + "runtime": { + "grpc": { + "port": grpc_port, + "enabled": test_config.server_mode + in [ServerMode.GRPC, ServerMode.BOTH], + "unix_socket_path": os.path.join( + config_overrides["use_in_test"]["workdir"], + "grpc.sock", + ) + if test_config.unix_socket + else None, + }, + "http": { + "enabled": test_config.server_mode + in [ServerMode.HTTP, ServerMode.BOTH], + }, + } + }, + "merge", + ): + # Health probe fails with no servers booted + assert not caikit_health_probe.readiness_probe() + # If booting the gRPC server, do so + with maybe_runtime_grpc_test_server(grpc_port): + # If only running gRPC, health probe should pass + assert caikit_health_probe.readiness_probe() == ( + test_config.should_become_healthy + and test_config.server_mode == ServerMode.GRPC + ) + # If booting the HTTP server, do so + with maybe_runtime_http_test_server( + http_port, + tls_config_override=config_overrides, + check_readiness=test_config.should_become_healthy, + ): + # Probe should always pass with both possible servers up + assert ( + caikit_health_probe.readiness_probe() + == test_config.should_become_healthy + ) + + +@pytest.mark.parametrize( + ["proc_identifier", "expected"], + [(None, True), ("caikit.runt", True), ("foobar", False)], +) +def test_liveness_probe(proc_identifier, expected): + """Test the logic for determining if the server process is alive""" + cmd = f"{sys.executable} -m caikit.runtime" + args = [] if proc_identifier is None else [proc_identifier] + + # Liveness should fail if process is not booted + assert not caikit_health_probe.liveness_probe(*args) + + proc = None + try: + # Start the process + proc = subprocess.Popen(shlex.split(cmd)) + + # Liveness should pass/fail as expected + assert caikit_health_probe.liveness_probe(*args) == expected + + finally: + # Kill the process if it started + if proc is not None and proc.poll() is None: + proc.kill() diff --git a/tests/runtime/test_dump_services.py b/tests/runtime/test_dump_services.py index b4d0e412d..4cb7f9942 100644 --- a/tests/runtime/test_dump_services.py +++ b/tests/runtime/test_dump_services.py @@ -16,35 +16,46 @@ import shutil import tempfile +# Third Party +import pytest + # First Party import alog # Local from caikit.runtime.dump_services import dump_grpc_services, dump_http_services +from tests.conftest import ARM_ARCH, PROTOBUF_VERSION ## Helpers ##################################################################### log = alog.use_channel("TEST-DUMP-I") +@pytest.mark.skipif( + PROTOBUF_VERSION < 4 and ARM_ARCH, reason="protobuf 3 serialization bug" +) def test_dump_grpc_services_dir_exists(): with tempfile.TemporaryDirectory() as workdir: - dump_grpc_services(workdir) + dump_grpc_services(workdir, False) assert os.path.exists(workdir) for file in os.listdir(workdir): 
assert file.endswith(".proto") +@pytest.mark.skipif( + PROTOBUF_VERSION < 4 and ARM_ARCH, reason="protobuf 3 serialization bug" +) def test_dump_grpc_services_dir_does_not_exist(): - fake_dir = "fake_dir" - dump_grpc_services(fake_dir) - assert os.path.exists(fake_dir) + with tempfile.TemporaryDirectory() as workdir: + fake_dir = os.path.join(workdir, "fake_dir") + dump_grpc_services(fake_dir, False) + assert os.path.exists(fake_dir) - for file in os.listdir(fake_dir): - assert file.endswith(".proto") + for file in os.listdir(fake_dir): + assert file.endswith(".proto") - shutil.rmtree(fake_dir) + shutil.rmtree(fake_dir) def test_dump_http_services_dir_exists(): @@ -58,12 +69,11 @@ def test_dump_http_services_dir_exists(): def test_dump_http_services_dir_does_not_exist(): - fake_dir = "fake_dir" - dump_http_services(fake_dir) - assert os.path.exists(fake_dir) - - for file in os.listdir(fake_dir): - assert file == "openapi.json" - assert os.path.getsize(os.path.join(fake_dir, file)) > 0 + with tempfile.TemporaryDirectory() as workdir: + fake_dir = os.path.join(workdir, "fake_dir") + dump_http_services(fake_dir) + assert os.path.exists(fake_dir) - shutil.rmtree(fake_dir) + for file in os.listdir(fake_dir): + assert file == "openapi.json" + assert os.path.getsize(os.path.join(fake_dir, file)) > 0 diff --git a/tests/runtime/test_grpc_server.py b/tests/runtime/test_grpc_server.py index 0c5fa549e..d9cbff24a 100644 --- a/tests/runtime/test_grpc_server.py +++ b/tests/runtime/test_grpc_server.py @@ -18,12 +18,10 @@ from dataclasses import dataclass from unittest import mock import json -import os import signal import tempfile import threading import time -import uuid # Third Party from google.protobuf.descriptor_pool import DescriptorPool @@ -34,6 +32,7 @@ ) import grpc import pytest +import requests import tls_test_tools # First Party @@ -41,15 +40,22 @@ # Local from caikit import get_config -from caikit.core import MODEL_MANAGER from caikit.core.data_model.producer import ProducerId from caikit.interfaces.runtime.data_model import ( + ModelInfoRequest, + ModelInfoResponse, + RuntimeInfoRequest, + RuntimeInfoResponse, TrainingInfoRequest, TrainingJob, - TrainingStatus, TrainingStatusResponse, ) -from caikit.runtime import get_inference_request, get_train_params, get_train_request +from caikit.runtime import ( + get_inference_request, + get_train_params, + get_train_request, + http_server, +) from caikit.runtime.grpc_server import RuntimeGRPCServer from caikit.runtime.model_management.model_manager import ModelManager from caikit.runtime.protobufs import ( @@ -69,11 +75,12 @@ ) from sample_lib.data_model.sample import OtherTask, SampleTask, StreamingTask from sample_lib.modules import FirstTask -from tests.conftest import random_test_id, temp_config +from tests.conftest import ARM_ARCH, PROTOBUF_VERSION, random_test_id, temp_config from tests.core.helpers import * from tests.fixtures import Fixtures from tests.runtime.conftest import ( ModuleSubproc, + _open_port, register_trained_model, runtime_grpc_test_server, ) @@ -832,7 +839,7 @@ def test_load_model_badmodel_error_response(runtime_grpc_server): modelKey="baz", ) stub.loadModel(load_model_request) - assert context.value.code() == grpc.StatusCode.INTERNAL + assert context.value.code() == grpc.StatusCode.NOT_FOUND def test_unload_model_ok_response(sample_task_model_id, runtime_grpc_server): @@ -953,6 +960,110 @@ def test_runtime_status_ok_response(runtime_grpc_server): assert actual_response.numericRuntimeVersion == 0 +def 
test_runtime_info_ok_response(runtime_grpc_server): + runtime_info_service: ServicePackage = ServicePackageFactory().get_service_package( + ServicePackageFactory.ServiceType.INFO, + ) + + runtime_info_stub = runtime_info_service.stub_class( + runtime_grpc_server.make_local_channel() + ) + + runtime_request = RuntimeInfoRequest() + runtime_info_response: RuntimeInfoResponse = RuntimeInfoResponse.from_proto( + runtime_info_stub.GetRuntimeInfo(runtime_request.to_proto()) + ) + + assert "caikit" in runtime_info_response.python_packages + # runtime_version not added if not set + assert runtime_info_response.runtime_version == "" + # dependent libraries not added if all packages not set to true + assert "py_to_proto" not in runtime_info_response.python_packages + + +def test_runtime_info_ok_response_all_packages(runtime_grpc_server): + with temp_config( + { + "runtime": { + "version_info": { + "python_packages": { + "all": True, + }, + "runtime_image": "1.2.3", + } + }, + }, + "merge", + ): + runtime_info_service: ServicePackage = ( + ServicePackageFactory().get_service_package( + ServicePackageFactory.ServiceType.INFO, + ) + ) + + runtime_info_stub = runtime_info_service.stub_class( + runtime_grpc_server.make_local_channel() + ) + + runtime_request = RuntimeInfoRequest() + runtime_info_response: RuntimeInfoResponse = RuntimeInfoResponse.from_proto( + runtime_info_stub.GetRuntimeInfo(runtime_request.to_proto()) + ) + + assert "caikit" in runtime_info_response.python_packages + assert runtime_info_response.runtime_version == "1.2.3" + # dependent libraries versions added + assert "alog" in runtime_info_response.python_packages + assert "py_to_proto" in runtime_info_response.python_packages + + +def test_all_model_info_ok_response(runtime_grpc_server, sample_task_model_id): + info_service: ServicePackage = ServicePackageFactory().get_service_package( + ServicePackageFactory.ServiceType.INFO, + ) + + model_info_stub = info_service.stub_class(runtime_grpc_server.make_local_channel()) + + model_request = ModelInfoRequest() + model_info_response: ModelInfoResponse = ModelInfoResponse.from_proto( + model_info_stub.GetModelsInfo(model_request.to_proto()) + ) + + assert len(model_info_response.models) > 0 + + found_sample_task = False + for model in model_info_response.models: + # Assert name and id exist + assert model.name and model.module_id + # Assert metadata module_name matches expected + if model.name == sample_task_model_id: + assert model.module_metadata.get("name") == "SampleModule" + found_sample_task = True + + assert found_sample_task, "Unable to find sample_task model in models list" + + +def test_single_model_info_ok_response(runtime_grpc_server, sample_task_model_id): + info_service: ServicePackage = ServicePackageFactory().get_service_package( + ServicePackageFactory.ServiceType.INFO, + ) + + model_info_stub = info_service.stub_class(runtime_grpc_server.make_local_channel()) + + model_request = ModelInfoRequest(model_ids=[sample_task_model_id]) + model_info_response: ModelInfoResponse = ModelInfoResponse.from_proto( + model_info_stub.GetModelsInfo(model_request.to_proto()) + ) + + # Assert only one model was returned + assert len(model_info_response.models) == 1 + model = model_info_response.models[0] + # Assert name and id exist + assert model.name and model.module_id + # Assert metadata module_name matches expected + assert model.module_metadata.get("name") == "SampleModule" + + #### Health Probe tests #### def test_grpc_health_probe_ok_response(runtime_grpc_server): """Test health 
check successful response""" @@ -962,6 +1073,9 @@ def test_grpc_health_probe_ok_response(runtime_grpc_server): assert actual_response.status == 1 +@pytest.mark.skipif( + PROTOBUF_VERSION < 4 and ARM_ARCH, reason="protobuf 3 serialization bug" +) def test_grpc_server_can_render_all_necessary_protobufs( runtime_grpc_server, sample_inference_service, sample_train_service, tmp_path ): @@ -995,8 +1109,8 @@ def test_canceling_model_loads_causes_exceptions(runtime_grpc_server): ) def never_return(*args, **kwargs): - request_received.set() try: + request_received.set() while True: time.sleep(0.01) except Exception as e: @@ -1013,7 +1127,7 @@ def never_return(*args, **kwargs): load_model_future.cancel() # Wait for an exception to be raised in our mock, and assert it was - request_finished.wait(10) + request_finished.wait(2) assert request_finished.is_set() @@ -1059,6 +1173,49 @@ def test_mtls(open_port): stub.Check(health_check_request) +def test_mtls_different_root(open_port): + """Make sure mtls communication works when the CA for the client is not the + same as the CA for the server (including health checks using the server's + CA) + """ + # Server TLS Infra + server_ca_key = tls_test_tools.generate_key()[0] + server_ca_cert = tls_test_tools.generate_ca_cert(server_ca_key) + server_tls_key, server_tls_cert = tls_test_tools.generate_derived_key_cert_pair( + server_ca_key + ) + + # Client TLS Infra + client_ca_key = tls_test_tools.generate_key()[0] + client_ca_cert = tls_test_tools.generate_ca_cert( + client_ca_key, common_name="my.client" + ) + client_tls_key, client_tls_cert = tls_test_tools.generate_derived_key_cert_pair( + client_ca_key, common_name="my.client" + ) + + server_tls_config = TLSConfig( + server=KeyPair(cert=server_tls_cert, key=server_tls_key), + client=KeyPair(cert=client_ca_cert, key=""), + ) + with runtime_grpc_test_server( + open_port, + tls_config_override=server_tls_config, + ) as server: + # Connect using the client's creds + _assert_connection( + _make_secure_channel( + server, server_ca_cert, client_tls_key, client_tls_cert + ) + ) + # Connect using the server's creds + _assert_connection( + _make_secure_channel( + server, server_ca_cert, server_tls_key, server_tls_cert + ) + ) + + @pytest.mark.parametrize( "enabled_services", [(True, False), (False, True), (False, False)], @@ -1244,6 +1401,63 @@ def test_grpc_sever_shutdown_with_model_poll(open_port): assert not server_proc.killed +def test_all_signal_handlers_invoked(open_port): + """Test that a SIGINT successfully shuts down all running servers""" + + # whoops, need 2 ports. 
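One for gRPC and one for HTTP.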
Try to find another open one that isn't the one we already have + other_open_port = _open_port(start=open_port + 1) + + with tempfile.TemporaryDirectory() as workdir: + server_proc = ModuleSubproc( + "caikit.runtime", + kill_timeout=30.0, + RUNTIME_GRPC_PORT=str(open_port), + RUNTIME_HTTP_PORT=str(other_open_port), + RUNTIME_LOCAL_MODELS_DIR=workdir, + RUNTIME_LAZY_LOAD_LOCAL_MODELS="true", + RUNTIME_LAZY_LOAD_POLL_PERIOD_SECONDS="0.1", + RUNTIME_METRICS_ENABLED="false", + RUNTIME_GRPC_ENABLED="true", + RUNTIME_HTTP_ENABLED="true", + LOG_LEVEL="info", + ) + with server_proc as proc: + # Wait for the grpc server to be up: + _assert_connection( + grpc.insecure_channel(f"localhost:{open_port}"), max_failures=500 + ) + + # Then wait for the http server as well: + http_failures = 0 + while http_failures < 500: + try: + resp = requests.get( + f"http://localhost:{other_open_port}{http_server.HEALTH_ENDPOINT}", + timeout=0.1, + ) + resp.raise_for_status() + break + except ( + requests.HTTPError, + requests.ConnectionError, + requests.ConnectTimeout, + ): + http_failures += 1 + # tiny sleep because a connection refused won't hit the full `0.1`s timeout + time.sleep(0.001) + + # Signal the server to shut down + proc.send_signal(signal.SIGINT) + + # Make sure the process was not killed + assert not server_proc.killed + # Check the logs (barf) to see if both grpc and http signal handlers called + # communicate returns (stdout, stderr) in bytes + logs = server_proc.proc.communicate()[1].decode("utf-8") + assert "Shutting down gRPC server" in logs + assert "Shutting down http server" in logs + + def test_construct_with_options(open_port, sample_train_service, sample_int_file): """Make sure that the server can be booted with config options""" with temp_config( @@ -1283,6 +1497,25 @@ def test_construct_with_options(open_port, sample_train_service, sample_int_file assert context.value.code() == grpc.StatusCode.RESOURCE_EXHAUSTED +def test_grpc_server_socket_listen(): + """Make sure that the server correctly listen on a unix socket""" + with tempfile.TemporaryDirectory() as socket_dir: + with temp_config( + {"runtime": {"grpc": {"unix_socket_path": socket_dir + "/grpc.sock"}}}, + "merge", + ): + with RuntimeGRPCServer(): + stub = model_runtime_pb2_grpc.ModelRuntimeStub( + grpc.insecure_channel(f"unix://{socket_dir}/grpc.sock") + ) + runtime_status_request = model_runtime_pb2.RuntimeStatusRequest() + actual_response = stub.runtimeStatus(runtime_status_request) + assert ( + actual_response.status + == model_runtime_pb2.RuntimeStatusResponse.READY + ) + + # Test implementation details ######################### @dataclass class KeyPair: diff --git a/tests/runtime/test_service_factory.py b/tests/runtime/test_service_factory.py index d36d12af3..563b738e6 100644 --- a/tests/runtime/test_service_factory.py +++ b/tests/runtime/test_service_factory.py @@ -37,7 +37,7 @@ from sample_lib.data_model import SampleInputType, SampleOutputType from sample_lib.data_model.sample import SampleTask from sample_lib.modules import ListModule, OtherModule -from tests.conftest import temp_config +from tests.conftest import ARM_ARCH, PROTOBUF_VERSION, temp_config from tests.core.helpers import MockBackend from tests.data_model_helpers import reset_global_protobuf_registry, temp_dpool from tests.runtime.conftest import sample_inference_service, sample_train_service @@ -55,7 +55,7 @@ def clean_data_model(sample_inference_service, sample_train_service): with reset_global_protobuf_registry(): with temp_dpool( inherit_global=True, - 
skip_inherit=[".*sampletask.*\.proto"], + skip_inherit=[r".*sampletask.*\.proto"], ) as dpool: yield dpool @@ -320,6 +320,9 @@ def test_override_package(clean_data_model): assert "SampleModule" in str(clean_modules) +@pytest.mark.skipif( + PROTOBUF_VERSION < 4 and ARM_ARCH, reason="protobuf 3 serialization bug" +) def test_override_package_and_domain_with_proto_gen(clean_data_model): """ Test override of both package and domain, to make sure they work together, and diff --git a/tests/runtime/work_management/test_abortable_action.py b/tests/runtime/work_management/test_abortable_action.py deleted file mode 100644 index 744e98571..000000000 --- a/tests/runtime/work_management/test_abortable_action.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Standard -import threading -import time -import unittest - -# Local -from caikit.runtime.types.aborted_exception import AbortedException -from caikit.runtime.work_management.abortable_action import AbortableAction -from caikit.runtime.work_management.rpc_aborter import RpcAborter -from tests.fixtures import Fixtures - - -class TestAbortableAction(unittest.TestCase): - """This test suite tests the abortable action class""" - - def setUp(self): - """This method runs before each test begins to run""" - self.rpc_context = Fixtures.build_context() - self.aborter = RpcAborter(self.rpc_context) - - def test_it_can_run_a_function(self): - expected_result = "test-any-result" - action = AbortableAction(self.aborter, lambda *args, **kwargs: expected_result) - result = action.do() - self.assertEqual(expected_result, result) - - def test_it_raises_if_the_rpc_has_already_terminated(self): - action = AbortableAction(self.aborter, lambda *args, **kwargs: None) - self.rpc_context.cancel() - - with self.assertRaises(AbortedException) as context: - action.do() - - def test_it_raises_if_the_function_raises(self): - expected_exception = ValueError("test-any-error") - - def thrower(): - raise expected_exception - - action = AbortableAction(self.aborter, thrower) - with self.assertRaises(ValueError) as ctx: - action.do() - - self.assertEqual(expected_exception, ctx.exception) - - def test_it_raises_if_the_rpc_is_terminated_mid_function(self): - infinite_function_has_started = threading.Event() - - def infinite_function(): - infinite_function_has_started.set() - while True: - time.sleep(0.1) - - action = AbortableAction(self.aborter, infinite_function) - - def inner_test_thread(): - with self.assertRaises(AbortedException) as context: - action.do() - - thread = threading.Thread(target=inner_test_thread) - thread.start() - infinite_function_has_started.wait() - - self.rpc_context.cancel() - thread.join(5) - self.assertFalse(thread.is_alive()) diff --git a/tests/runtime/work_management/test_abortable_context.py b/tests/runtime/work_management/test_abortable_context.py new file mode 100644 index 000000000..609c3fb24 --- /dev/null +++ b/tests/runtime/work_management/test_abortable_context.py @@ -0,0 +1,167 @@ +# 
Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Standard +from concurrent.futures import Future, ThreadPoolExecutor +import dataclasses +import datetime +import random +import time + +# Third Party +import grpc +import pytest + +# Local +from caikit.runtime.types.aborted_exception import AbortedException +from caikit.runtime.work_management.abortable_context import ( + AbortableContext, + ThreadInterrupter, +) +from caikit.runtime.work_management.rpc_aborter import RpcAborter +from tests.fixtures import Fixtures + +## Helpers ##################################################################### + + +@pytest.fixture(scope="session") +def thread_interrupter(): + interrupter = ThreadInterrupter() + interrupter.start() + + yield interrupter + + interrupter.stop() + + +@pytest.fixture() +def grpc_context() -> grpc.ServicerContext: + return Fixtures.build_context("abortable-context-test") + + +@pytest.fixture() +def rpc_aborter(grpc_context): + return RpcAborter(grpc_context) + + +def wait_for_interrupter_to_run(interrupter, timeout=1): + """Helper to wait until the interrupter's queue is empty. + This should only deadlock if the interrupter's polling thread exits. + """ + start = datetime.datetime.now() + while (datetime.datetime.now() - start).total_seconds() < timeout: + if interrupter._queue.empty(): + return + time.sleep(0.001) + + +## Tests ####################################################################### + + +def test_context_runs_stuff(thread_interrupter, rpc_aborter): + """Just an ordinary context manager here""" + one_plus_one = 0 + with AbortableContext(rpc_aborter, thread_interrupter): + one_plus_one += 2 + + assert one_plus_one == 2 + + +def test_context_can_be_canceled(thread_interrupter, rpc_aborter, grpc_context): + """An AbortedException is raised as soon as the rpc context is canceled""" + result = 0 + with pytest.raises(AbortedException): + with AbortableContext(rpc_aborter, thread_interrupter): + result += 1 + grpc_context.cancel() + assert not thread_interrupter._queue.empty() + wait_for_interrupter_to_run(thread_interrupter) + assert False + + assert result == 1 + + +def test_context_aborts_if_rpc_already_canceled( + thread_interrupter, rpc_aborter, grpc_context +): + """The context will abort if the rpc context was previously canceled""" + grpc_context.cancel() + + with pytest.raises(AbortedException): + with AbortableContext(rpc_aborter, thread_interrupter): + wait_for_interrupter_to_run(thread_interrupter) + assert False + + +def test_exceptions_can_be_raised_in_context(thread_interrupter, rpc_aborter): + """Exceptions work normally""" + + with pytest.raises(ValueError, match="this is a test"): + with AbortableContext(rpc_aborter, thread_interrupter): + raise ValueError("this is a test") + + +def test_many_threads_can_run_in_abortable_context_at_once(thread_interrupter): + """This test tries to replicate a multithreaded situation where many threads can complete + an AbortableContext and many others are aborted. 
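Each task gets its own RpcAborter and shares the single session-scoped ThreadInterrupter.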
We want to make sure only the contexts that + we canceled are actually aborted- i.e. the interrupter interrupts the correct contexts.""" + + @dataclasses.dataclass + class TestTask: + + context: grpc.ServicerContext + wait_for_cancel: bool + future: Future = None + + def run(self): + """Dummy task that either returns quickly or spins forever waiting to be interrupted""" + aborter = RpcAborter(self.context) + with AbortableContext(aborter=aborter, interrupter=thread_interrupter): + if self.wait_for_cancel: + while True: + time.sleep(0.001) + else: + time.sleep(0.001) + + # Create a bunch of tasks, half of them need to be interrupted + tasks = [] + for i in range(25): + tasks.append( + TestTask( + context=Fixtures.build_context(f"test-task-{i}"), wait_for_cancel=False + ) + ) + for i in range(25): + tasks.append( + TestTask( + context=Fixtures.build_context(f"test-cancel-task-{i}"), + wait_for_cancel=True, + ) + ) + random.shuffle(tasks) + + # Submit them all and cancel the context of the ones that need interrupting + pool = ThreadPoolExecutor(max_workers=50) + for t in tasks: + t.future = pool.submit(t.run) + for t in tasks: + if t.wait_for_cancel: + t.context.cancel() + + # Assert that the ones we canceled throw, and the rest don't + for t in tasks: + if t.wait_for_cancel: + with pytest.raises(AbortedException): + t.future.result() + else: + t.future.result() diff --git a/tests/runtime/work_management/test_call_aborter.py b/tests/runtime/work_management/test_call_aborter.py index f1e002cdd..661799f01 100644 --- a/tests/runtime/work_management/test_call_aborter.py +++ b/tests/runtime/work_management/test_call_aborter.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Standard -import threading import unittest # Local @@ -20,34 +19,50 @@ from tests.fixtures import Fixtures -class TestRpcAborter(unittest.TestCase): - """This test suite tests the call aborter utility""" - - CHANNEL = None - GRPC_THREAD = None - - def test_call_aborter_sets_event(self): - ctx = Fixtures.build_context("call_aborter_event_party") - # Create a new Call aborter - aborter = RpcAborter(ctx) - # Create a new threading event and add it to the call aborter - event = threading.Event() - aborter.add_event(event) - # Cancel the call & wait for the threading event to be set by __rpc_terminated - ctx.cancel() - event.wait() - self.assertTrue(aborter.must_abort()) - - def test_call_aborter_sets_event_added_after_termination(self): - ctx = Fixtures.build_context("call_aborter_event_party") - # Create a new call aborter - aborter = RpcAborter(ctx) - # Cancel the call before creating the threading event and adding to the aborter - ctx.cancel() - event = threading.Event() - aborter.add_event(event) - event.wait() - self.assertTrue(aborter.must_abort()) +class StubAbortableContext: + """Test context, simply sets flag if `abort` was called""" + + def __init__(self): + self.aborted = False + + def abort(self): + self.aborted = True + + +def test_call_aborter_invokes_abortable_context(): + """The whole reason this class exists: + If the grpc context is canceled, the abortable context should be aborted + """ + grpc_ctx = Fixtures.build_context("call_aborter_event_party") + abort_ctx = StubAbortableContext() + + # Create a new Call aborter + aborter = RpcAborter(grpc_ctx) + # Set its abort context + aborter.set_context(abort_ctx) + + assert not abort_ctx.aborted + + # Cancel the call and check that context was aborted + grpc_ctx.cancel() + + assert 
abort_ctx.aborted + + +def test_call_aborter_invokes_abortable_context_when_grpc_context_is_already_canceled(): + """Edge case: if the grpc context has already been canceled, the abortable context is immediately aborted as well""" + grpc_ctx = Fixtures.build_context("call_aborter_event_party") + abort_ctx = StubAbortableContext() + + # Prematurely cancel grpc context + grpc_ctx.cancel() + + # Create a new Call aborter + aborter = RpcAborter(grpc_ctx) + # Set its abort context + aborter.set_context(abort_ctx) + # And it should immediately abort + assert abort_ctx.aborted if __name__ == "__main__": diff --git a/tox.ini b/tox.ini index 56b094084..2fa5ad525 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,11 @@ passenv = LOG_FORMATTER LOG_THREAD_ID LOG_CHANNEL_WIDTH -commands = pytest --cov=caikit --cov-report=html:coverage-{env_name} --cov-report=xml:coverage-{env_name}.xml --html=durations/{env_name}.html {posargs:tests -m "not examples"} -W error::UserWarning +setenv = + DFTYPE = pandas_all + +commands = pytest --cov=caikit --cov-report=html:coverage-{env_name} --cov-report=xml:coverage-{env_name}.xml --html=durations/{env_name}.html {posargs:tests -m "not (examples or slow)"} -W error::UserWarning +; -W ignore::DeprecationWarning ; Unclear: We probably want to test wheel packaging ; But! tox will fail when this is set and _any_ interpreter is missing @@ -36,12 +40,12 @@ commands = ./scripts/fmt.sh allowlist_externals = ./scripts/fmt.sh [testenv:lint] -description = lint with pylint +description = lint with ruff extras = all dev-fmt dev-test -commands = pylint caikit examples/text-sentiment/text_sentiment examples/text-sentiment/*.py examples/sample_lib/*.py +commands = ruff check caikit examples [testenv:imports] description = enforce internal import rules @@ -65,10 +69,9 @@ commands = twine check dist/* # Ensure compatibility is maintained with protobuf 3.X [testenv:proto3] description = run tests with pytest with coverage +extras = dev-proto3 commands = - pip uninstall grpcio-health-checking grpcio-reflection -y - pip install protobuf==3.19.0 grpcio-health-checking grpcio-reflection --upgrade - pytest --cov=caikit --cov-report=html {posargs:tests -m "not examples"} + pytest --cov=caikit --cov-report=html {posargs:tests -m "not (examples or slow)"} # Ensure tests targeting caikit.core can be run with no optional dependencies [testenv:core]