Skip to content
This repository has been archived by the owner on Jul 15, 2024. It is now read-only.

Commit

Permalink
Merge branch 'caikit:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Xaenalt authored Oct 23, 2023
2 parents 29ace72 + a0fb1fb commit 04bdc46
Show file tree
Hide file tree
Showing 62 changed files with 3,000 additions and 619 deletions.
20 changes: 20 additions & 0 deletions caikit/config/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,22 @@ data_streams:
# Base directory where stream source relative paths should be found
file_source_base: null

# Config for data stream source plugins. The keys in the map are names for
# environment scoping and the values are factory blobs for an importable
# factory. Each blob requires a `type` and can optionally have `config` and
# `import_class`.
source_plugins:
inline:
type: JsonData
file:
type: FileData
file_list:
type: ListOfFiles
directory:
type: Directory
s3:
type: S3Files

### Runtime configurations
runtime:
# The runtime library (or libraries) whose models we want to serve using Caikit Runtime. This should
Expand Down Expand Up @@ -146,6 +162,10 @@ runtime:
task_types:
included: []
excluded: []
backwards_compatibility:
enabled: false
current_modules_path: modules.json
prev_modules_path: prev_modules.json

# Configuration for batch inference. This dict contains entries for individual
# model types and can be extended using additional overlay config files to add
Expand Down
57 changes: 54 additions & 3 deletions caikit/core/data_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# Standard
from dataclasses import dataclass
from enum import Enum
from io import IOBase
from typing import Any, Dict, List, Optional, Type, Union
import base64
import datetime
Expand Down Expand Up @@ -79,6 +80,10 @@ class _DataBaseMetaClass(type):
if name.startswith("TYPE_") and "INT" in name
]

# Add property to track if a class supports exporting and importing via a
# file operation
supports_file_operations = False

def __new__(mcs, name, bases, attrs):
"""When constructing a new data model class, we set the 'fields' class variable from the
protobufs descriptor and then set the '__slots__' magic class attribute to fields. This
Expand Down Expand Up @@ -301,6 +306,13 @@ def parse_proto_descriptor(mcs, cls):
if current_init is None or current_init is DataBase.__init__:
setattr(cls, "__init__", mcs._make_init(cls.fields))

# Check DataBase for file handlers
setattr(
cls,
"supports_file_operations",
cls.to_file != DataBase.to_file and cls.from_file != DataBase.from_file,
)

@classmethod
def _make_property_getter(mcs, field, oneof_name=None):
"""This helper creates an @property attribute getter for the given field
Expand Down Expand Up @@ -710,7 +722,17 @@ def from_proto(cls, proto):
oneof = cls._fields_to_oneof[field]
contained_class = cls.get_class_for_proto(proto_attr)
contained_obj = contained_class.from_proto(proto_attr)
kwargs[oneof] = getattr(contained_obj, "values")
if hasattr(contained_obj, "values") and (
contained_class.__module__.startswith(
"caikit.core.data_model"
)
or contained_class.__module__.startswith(
"caikit.interfaces.common.data_model"
)
):
kwargs[oneof] = getattr(contained_obj, "values")
else:
kwargs[oneof] = contained_obj
else:
contained_class = cls.get_class_for_proto(proto_attr)
contained_obj = contained_class.from_proto(proto_attr)
Expand Down Expand Up @@ -742,13 +764,14 @@ def from_proto(cls, proto):
return cls(**kwargs)

@classmethod
def from_json(cls, json_str):
def from_json(cls, json_str, ignore_unknown_fields=False):
"""Build a DataBase from a given JSON string. Use google's protobufs.json_format for
deserialization
Args:
json_str (str or dict): A stringified JSON specification/dict of the
data_model
ignore_unknown_fields (bool): If True, ignores unknown JSON fields
Returns:
caikit.core.data_model.DataBase: A DataBase object.
Expand All @@ -763,7 +786,9 @@ def from_json(cls, json_str):
try:
# Parse given JSON into google.protobufs.pyext.cpp_message.GeneratedProtocolMessageType
parsed_proto = json_format.Parse(
json_str, cls.get_proto_class()(), ignore_unknown_fields=False
json_str,
cls.get_proto_class()(),
ignore_unknown_fields=ignore_unknown_fields,
)

# Use from_proto to return the DataBase object from the parsed proto
Expand All @@ -772,6 +797,19 @@ def from_json(cls, json_str):
except json_format.ParseError as ex:
error("<COR90619980E>", ValueError(ex))

@classmethod
def from_file(cls, file_obj: IOBase):
    """Construct a DataBase instance by reading a file-like object.

    The base implementation is intentionally unimplemented; subclasses
    that support file-based deserialization override this method (the
    metaclass uses the override to set `supports_file_operations`).

    Args:
        file_obj (IOBase): A readable file object containing some
            representation of the dataobject

    Returns:
        caikit.core.data_model.DataBase: The deserialized DataBase object.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError(f"from_file not implemented for {cls}")

def to_proto(self):
"""Return a new protobufs populated with the information in this data structure."""
# get the name of the protobufs class
Expand Down Expand Up @@ -948,6 +986,19 @@ def _default_serialization_overrides(obj):

return json.dumps(self.to_dict(), **kwargs)

def to_file(self, file_obj: IOBase) -> Optional["File"]:
    """Export a DataBase object into a file-like object `file_obj`. If the
    DataBase object has requirements around file name or file type it can
    return them via the optional "File" return object.

    The base implementation is intentionally unimplemented; subclasses
    that support file-based serialization override this method.

    Args:
        file_obj (IOBase): A writable file object to be filled with the
            serialized representation of this object

    Returns:
        Optional[caikit.interfaces.common.data_model.File]: An optional
            file descriptor with name/type requirements, or None

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError(f"to_file not implemented for {self.__class__}")

def __repr__(self):
"""Human-friendly representation."""
return self.to_json(indent=2, ensure_ascii=False)
Expand Down
7 changes: 7 additions & 0 deletions caikit/core/data_model/streams/data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,16 @@ def _from_json_array_buffer_generator(cls, json_fh: typing.IO, filename: str = "
)
# For each {} object of the array
try:
item_idx = None
for item_idx, obj in enumerate(ijson.items(json_fh, "item")):
log.debug2("Loading object index %d", item_idx)
yield obj
if item_idx is None:
# Not an array
error(
"<COR79428339E>",
ValueError("Non-array JSON object in `{}`".format(filename)),
)
except ijson.JSONError:
error(
"<COR85596551E>",
Expand Down
66 changes: 1 addition & 65 deletions caikit/core/model_management/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,78 +14,14 @@
"""
Global factories for model management
"""
# Standard
from typing import Optional
import importlib

# First Party
import alog

# Local
from ..exceptions import error_handler
from ..toolkit.factory import Factory, FactoryConstructible
from ..toolkit.factory import ImportableFactory
from .local_model_finder import LocalModelFinder
from .local_model_initializer import LocalModelInitializer
from .local_model_trainer import LocalModelTrainer
from .multi_model_finder import MultiModelFinder

log = alog.use_channel("MMFCTRY")
error = error_handler.get(log)


class ImportableFactory(Factory):
    """An ImportableFactory extends the base Factory to allow the construction
    to specify an "import_class" field that will be used to import and register
    the implementation class before attempting to initialize it.
    """

    # Config key whose (string) value names the fully-qualified class to
    # dynamically import, e.g. "my_pkg.my_module.MyFinder"
    IMPORT_CLASS_KEY = "import_class"

    def construct(
        self,
        instance_config: dict,
        instance_name: Optional[str] = None,
    ):
        """Construct an instance, first importing and registering the class
        named by the optional `import_class` config field.

        Args:
            instance_config (dict): Factory blob; may contain
                `import_class` in addition to the base Factory fields
            instance_name (Optional[str]): Optional name for the instance,
                passed through to the base Factory

        Returns:
            The constructed instance from the base Factory
        """
        # Look for an import_class and import and register it if found
        import_class_val = instance_config.get(self.__class__.IMPORT_CLASS_KEY)
        if import_class_val:
            # Validate that the configured value is a string before parsing
            error.type_check(
                "<COR85108801E>",
                str,
                **{self.__class__.IMPORT_CLASS_KEY: import_class_val},
            )
            # Split "pkg.module.ClassName" into module path and class name
            module_name, class_name = import_class_val.rsplit(".", 1)
            try:
                imported_module = importlib.import_module(module_name)
            except ImportError:
                # NOTE(review): error(...) presumably raises and aborts here;
                # confirm against the error_handler implementation
                error(
                    "<COR46837141E>",
                    ValueError(
                        "Invalid {}: Module cannot be imported [{}]".format(
                            self.__class__.IMPORT_CLASS_KEY,
                            module_name,
                        )
                    ),
                )
            try:
                imported_class = getattr(imported_module, class_name)
            except AttributeError:
                error(
                    "<COR46837142E>",
                    ValueError(
                        "Invalid {}: No such class [{}] on module [{}]".format(
                            self.__class__.IMPORT_CLASS_KEY,
                            class_name,
                            module_name,
                        )
                    ),
                )
            # The imported class must be constructible by a Factory
            error.subclass_check("<COR52306423E>", imported_class, FactoryConstructible)

            self.register(imported_class)
        # Delegate the actual construction to the base Factory
        return super().construct(instance_config, instance_name)


# Model trainer factory. A trainer is responsible for performing the train
# operation against a configured framework connection.
model_trainer_factory = ImportableFactory("ModelTrainer")
Expand Down
28 changes: 24 additions & 4 deletions caikit/core/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# Standard
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import collections
import os
import shutil
Expand All @@ -48,7 +48,7 @@
log = alog.use_channel("MODULE")
error = error_handler.get(log)


# pylint: disable=too-many-public-methods
class ModuleBase(metaclass=_ModuleBaseMeta):
"""Abstract base class from which all modules should inherit."""

Expand Down Expand Up @@ -86,16 +86,36 @@ def set_load_backend(self, load_backend):

@classmethod
def get_inference_signature(
cls, input_streaming: bool, output_streaming: bool
cls,
input_streaming: bool,
output_streaming: bool,
task: Type["caikit.core.TaskBase"] = None,
) -> Optional["caikit.core.signature_parsing.CaikitMethodSignature"]:
"""Returns the inference method signature that is capable of running the module's task
for the given flavors of input and output streaming
"""
for in_streaming, out_streaming, signature in cls._INFERENCE_SIGNATURES:

if task is not None and task in cls._TASK_INFERENCE_SIGNATURES:
signatures = cls._TASK_INFERENCE_SIGNATURES[task]
elif cls._TASK_INFERENCE_SIGNATURES:
signatures = next(iter(cls._TASK_INFERENCE_SIGNATURES.values()))
else:
signatures = []

for in_streaming, out_streaming, signature in signatures:
if in_streaming == input_streaming and out_streaming == output_streaming:
return signature
return None

@classmethod
def get_inference_signatures(
    cls, task: Type["caikit.core.TaskBase"]
) -> Optional[
    List[Tuple[bool, bool, "caikit.core.signature_parsing.CaikitMethodSignature"]]
]:
    """Returns inference method signatures for all supported flavors
    of input and output streaming for a given task.

    Args:
        task: The task class whose inference signatures should be fetched

    Returns:
        The list of (input_streaming, output_streaming, signature) tuples
        registered for the given task, or None if the task is not known
        to this module. (Annotation corrected: dict.get returns None for
        missing tasks, so the return is Optional.)
    """
    return cls._TASK_INFERENCE_SIGNATURES.get(task)

@property
def load_backend(self):
"""Get the backend instance used to load this module. This can be used
Expand Down
Loading

0 comments on commit 04bdc46

Please sign in to comment.