Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: get_service api for selected frameworks #4782

Merged
merged 18 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/reference/frameworks/catboost.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ CatBoost
.. admonition:: About this page

This is an API reference for CatBoost in BentoML. Please refer to
:ref:`CatBoost guides <frameworks/catboost:CatBoost>` for more information about how to use CatBoost
:ref:`CatBoost guides </reference/frameworks/catboost:CatBoost>` for more information about how to use CatBoost
in BentoML.


Expand All @@ -16,3 +16,5 @@ CatBoost
.. autofunction:: bentoml.catboost.load_model

.. autofunction:: bentoml.catboost.get

.. autofunction:: bentoml.catboost.get_service
2 changes: 1 addition & 1 deletion docs/source/reference/frameworks/detectron.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Detectron
.. admonition:: About this page

This is an API reference for Detectron in BentoML. Please refer to
:doc:`Detectron guide </frameworks/detectron>` for more information about how to use
:doc:`Detectron guide </reference/frameworks/detectron>` for more information about how to use
Detectron in BentoML.

.. currentmodule:: bentoml.detectron
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/frameworks/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Framework APIs
.. note::

This page contains the API reference for all frameworks. For more information on a specific
framework, please see :doc:`/frameworks/index`
framework, please see :doc:`/reference/frameworks/index`

.. grid:: 1 2 2 2
:gutter: 3
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/frameworks/lightgbm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ LightGBM
.. autofunction:: bentoml.lightgbm.load_model

.. autofunction:: bentoml.lightgbm.get

.. autofunction:: bentoml.lightgbm.get_service
2 changes: 1 addition & 1 deletion docs/source/reference/frameworks/mlflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ MLflow

.. autofunction:: bentoml.mlflow.get

.. autofunction:: bentoml.mlflow.get_mlflow_model
.. autofunction:: bentoml.mlflow.get_service
2 changes: 1 addition & 1 deletion docs/source/reference/frameworks/transformers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Transformers
.. admonition:: About this page

This is an API reference for 🤗 Transformers in BentoML. Please refer to
:doc:`Transformers guide </frameworks/transformers>` for more information about how to use
:doc:`Transformers guide </reference/frameworks/transformers>` for more information about how to use
Hugging Face Transformers in BentoML.

.. currentmodule:: bentoml.transformers
Expand Down
4 changes: 3 additions & 1 deletion docs/source/reference/frameworks/xgboost.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ XGBoost
.. admonition:: About this page

This is an API reference for XGBoost in BentoML. Please refer to the
:doc:`XGBoost guide </frameworks/xgboost>` for more information about how to use XGBoost
:doc:`XGBoost guide </reference/frameworks/xgboost>` for more information about how to use XGBoost
in BentoML.


Expand All @@ -22,3 +22,5 @@ XGBoost
.. autofunction:: bentoml.xgboost.load_model

.. autofunction:: bentoml.xgboost.get

.. autofunction:: bentoml.xgboost.get_service
80 changes: 32 additions & 48 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,33 @@

PYTHON_VERSIONS = ["3.8", "3.9", "3.10", "3.11"]

FRAMEWORK_DEPENDENCIES = {
"catboost": ["catboost"],
"diffusers": ["diffusers", "transformers", "tokenizer"],
"easyocr": ["easyocr"],
"fastai": ["fastai"],
"flax": [
"tensorflow~=2.13.1",
"pydantic<2",
"flax; platform_system!='Windows'",
"jax[cpu]; platform_system!='Windows'",
"jaxlib; platform_system!='Windows'",
"chex; platform_system!='Windows'",
],
"keras": ["keras"],
"lightgbm": ["lightgbm"],
"onnx": ["onnx", "onnxruntime", "skl2onnx"],
"picklable_model": [],
"pytorch": [],
"pytorch_lightning": ["lightning"],
"sklearn": ["scikit-learn"],
"tensorflow": ["tensorflow~=2.13.1", "pydantic<2"],
"torchscript": [],
"xgboost": ["xgboost"],
"detectron": ["detectron2"],
"transformers": ["transformers", "tokenizer"],
}


@nox.session(python=PYTHON_VERSIONS, name="unit")
def run_unittest(session: nox.Session):
Expand All @@ -22,28 +49,7 @@ def run_unittest(session: nox.Session):


@nox.session(name="framework-integration")
@nox.parametrize(
"framework",
[
"catboost",
"diffusers",
"easyocr",
"fastai",
"flax",
"keras",
"lightgbm",
"onnx",
"picklable_model",
"pytorch",
"pytorch_lightning",
"sklearn",
"tensorflow",
"torchscript",
"xgboost",
"detectron",
"transformers",
],
)
@nox.parametrize("framework", list(FRAMEWORK_DEPENDENCIES))
def run_framework_integration_test(session: nox.Session, framework: str):
session.run("pdm", "sync", "-G", "testing", external=True)
session.install(
Expand All @@ -53,32 +59,10 @@ def run_framework_integration_test(session: nox.Session, framework: str):
"-i",
"https://download.pytorch.org/whl/cpu",
)
session.install(
"catboost",
"lightgbm",
"mlflow",
"fastai",
"xgboost",
"scikit-learn",
"easyocr",
"datasets",
# ONNX dependencies
"onnx",
"onnxruntime",
"skl2onnx",
# tensorflow dependencies
"tensorflow~=2.13.1",
# torch-related dependencies
"lightning",
# huggingface dependencies
"transformers",
"tokenizer",
"diffusers",
"flax; platform_system!='Windows'",
"jax[cpu]; platform_system!='Windows'",
"jaxlib; platform_system!='Windows'",
"chex; platform_system!='Windows'",
)
deps = FRAMEWORK_DEPENDENCIES[framework]
if deps:
session.install(*deps)
session.run("pdm", "list", "--tree")
session.run(
*TEST_ARGS,
"tests/integration/frameworks/test_frameworks.py",
Expand Down
35 changes: 35 additions & 0 deletions src/_bentoml_impl/frameworks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

import importlib.util
from importlib.abc import MetaPathFinder
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from importlib.machinery import ModuleSpec
from types import ModuleType
from typing import Sequence


class FrameworkImporter(MetaPathFinder):
def find_spec(
self, fullname: str, path: Sequence[str] | None, target: ModuleType | None = ...
) -> ModuleSpec | None:
if not fullname.startswith("bentoml."):
return None
framework = fullname.split(".")[1]
if "." in framework:
return None
spec = importlib.util.find_spec(f"_bentoml_impl.frameworks.{framework}")
if spec is None:
spec = importlib.util.find_spec(f"bentoml._internal.frameworks.{framework}")
return spec

@classmethod
def install(cls) -> None:
import sys

for finder in sys.meta_path:
if isinstance(finder, cls):
return

sys.meta_path.append(cls())
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@

import attr
import numpy as np
from typing_extensions import deprecated

import bentoml
from bentoml import Tag
from bentoml._internal.models.model import ModelContext
from bentoml._internal.utils.pkg import get_pkg_version
from bentoml.exceptions import BentoMLException
from bentoml.exceptions import InvalidArgument
from bentoml.exceptions import MissingDependencyException
from bentoml.exceptions import NotFound
from bentoml.models import ModelOptions

from ..models.model import ModelContext
from ..utils.pkg import get_pkg_version
from bentoml.models import ModelOptions as BaseModelOptions
from bentoml.models import get as get

if TYPE_CHECKING:
from typing_extensions import Unpack

from _bentoml_sdk import Service
from _bentoml_sdk import ServiceConfig
from bentoml.types import ModelSignature
from bentoml.types import ModelSignatureDict

Expand All @@ -37,34 +42,7 @@
DEFAULT_MODEL_TRAINING_CLASS_NAME = "CatBoost"
API_VERSION = "v1"

logger = logging.getLogger(__name__)


def get(tag_like: str | Tag) -> bentoml.Model:
"""
Get the BentoML model with the given tag.

Args:
tag_like (``str`` ``|`` :obj:`~bentoml.Tag`):
The tag of the model to retrieve from the model store.

Returns:
:obj:`~bentoml.Model`: A BentoML :obj:`~bentoml.Model` with the matching tag.

Example:

.. code-block:: python

import bentoml
# target model must be from the BentoML model store
model = bentoml.catboost.get("my_catboost_model")
"""
model = bentoml.models.get(tag_like)
if model.info.module not in (MODULE_NAME, __name__):
raise NotFound(
f"Model {model.tag} was saved with module {model.info.module}, not loading with {MODULE_NAME}."
)
return model
logger = logging.getLogger(MODULE_NAME)


def load_model(bento_model: str | Tag | bentoml.Model) -> cb.CatBoost:
Expand All @@ -88,7 +66,7 @@ def load_model(bento_model: str | Tag | bentoml.Model) -> cb.CatBoost:
booster = bentoml.catboost.load_model("my_catboost_model")
""" # noqa: LN001
if not isinstance(bento_model, bentoml.Model):
bento_model = get(bento_model)
bento_model = bentoml.models.get(bento_model)

if bento_model.info.module not in (MODULE_NAME, __name__):
raise NotFound(
Expand All @@ -106,7 +84,7 @@ def load_model(bento_model: str | Tag | bentoml.Model) -> cb.CatBoost:


@attr.define
class CatBoostOptions(ModelOptions):
class ModelOptions(BaseModelOptions):
training_class_name: str = attr.field(factory=str)


Expand Down Expand Up @@ -200,7 +178,7 @@ def save_model(
name,
)

options = CatBoostOptions(
options = ModelOptions(
training_class_name=model.__class__.__name__,
)

Expand All @@ -221,6 +199,7 @@ def save_model(
return bento_model


@deprecated("`get_runnable` is a legacy API, use `get_service` instead.")
def get_runnable(bento_model: bentoml.Model) -> t.Type[bentoml.Runnable]:
"""
Private API: use :obj:`~bentoml.Model.to_runnable` instead.
Expand Down Expand Up @@ -281,3 +260,51 @@ def _run(self: CatBoostRunnable, input_data: t.Any) -> t.Any:
add_runnable_method(method_name, options)

return CatBoostRunnable


def get_service(model_name: str, **config: Unpack[ServiceConfig]) -> Service[t.Any]:
"""
Get a BentoML service for the catboost model given by name.

Args:
model_name (``str``):
The name of the model to get the service for.
**config (``Unpack[ServiceConfig]``):
Configuration options for the service.
Returns:
A BentoML service instance that wraps the CatBoost model.
Example:

.. code-block:: python

import bentoml

service = bentoml.catboost.get_service("my_catboost_model")
"""

@bentoml.service(**config)
class CatBoostService:
bento_model = bentoml.models.get(model_name)

def __init__(self) -> None:
self.model = load_model(self.bento_model)
self.predict_params = {"task_type": "CPU"}

# check for resources
available_gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
if available_gpus not in ("", "-1"):
self.predict_params["task_type"] = "GPU"
else:
nthreads = os.getenv("OMP_NUM_THREADS")
if nthreads is not None and nthreads != "":
nthreads = max(int(nthreads), 1)
else:
nthreads = -1
self.predict_params["thread_count"] = nthreads

@bentoml.api
def predict(self, input_data: np.ndarray) -> np.ndarray:
rv = self.model.predict(data=cb.Pool(input_data), **self.predict_params)
return np.asarray(rv) # type: ignore (incomplete np types)

return CatBoostService
Loading
Loading