Skip to content

Commit

Permalink
Merge pull request #425 from Jooho/odh_master
Browse files Browse the repository at this point in the history
[20241101] ODH Master Sync
Jooho authored Nov 1, 2024

Verified

This commit was signed with the committer’s verified signature.
xiaofei-du Xiaofei Du
2 parents 53566c7 + ee7a6c5 commit cd93dda
Showing 58 changed files with 3,388 additions and 1,035 deletions.
1 change: 1 addition & 0 deletions .github/workflows/e2e-test.yml
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@ env:
PMML_IMG: "pmmlserver"
PADDLE_IMG: "paddleserver"
CUSTOM_MODEL_GRPC_IMG: "custom-model-grpc"
CUSTOM_MODEL_GRPC_IMG_TAG: "kserve/custom-model-grpc:${{ github.sha }}"
HUGGINGFACE_IMG: "huggingfaceserver"
# Explainer images
ART_IMG: "art-explainer"
9 changes: 6 additions & 3 deletions .github/workflows/scheduled-image-scan.yml
Original file line number Diff line number Diff line change
@@ -43,13 +43,14 @@ jobs:
--sarif-file-output=./application/${{ matrix.image.name }}/docker.snyk.sarif
sarif: false

# Replace any "undefined" security severity values with 0. The undefined value is used in the case
# Replace any "undefined" or "null" security severity values with 0. The undefined value is used in the case
# of license-related findings, which do not indicate a security vulnerability.
# See https://github.com/github/codeql-action/issues/2187 for more context.
# This can be removed once https://github.com/snyk/cli/pull/5409 is merged.
- name: Replace undefined or null security-severity for license-related findings
  # A single in-place pass rewrites both placeholder values ("undefined" and
  # "null") to "0"; every other security-severity value is left untouched.
  run: |
    sudo sed -i -E 's/"security-severity": "(undefined|null)"/"security-severity": "0"/g' "./application/${{ matrix.image.name }}/docker.snyk.sarif"
- name: Upload sarif file to Github Code Scanning
if: always()
@@ -88,13 +89,14 @@ jobs:
--sarif-file-output=./application/${{ matrix.image.name }}/docker.snyk.sarif
sarif: false

# Replace any "undefined" security severity values with 0. The undefined value is used in the case
# Replace any "undefined" or "null" security severity values with 0. The undefined value is used in the case
# of license-related findings, which do not indicate a security vulnerability.
# See https://github.com/github/codeql-action/issues/2187 for more context.
# This can be removed once https://github.com/snyk/cli/pull/5409 is merged.
- name: Replace undefined or null security-severity for license-related findings
  # A single in-place pass rewrites both placeholder values ("undefined" and
  # "null") to "0"; every other security-severity value is left untouched.
  run: |
    sudo sed -i -E 's/"security-severity": "(undefined|null)"/"security-severity": "0"/g' "./application/${{ matrix.image.name }}/docker.snyk.sarif"
- name: Upload sarif file to Github Code Scanning
if: always()
@@ -129,13 +131,14 @@ jobs:
--sarif-file-output=./application/${{ matrix.image.name }}/docker.snyk.sarif
sarif: false

# Replace any "undefined" security severity values with 0. The undefined value is used in the case
# Replace any "undefined" or "null" security severity values with 0. The undefined value is used in the case
# of license-related findings, which do not indicate a security vulnerability.
# See https://github.com/github/codeql-action/issues/2187 for more context.
# This can be removed once https://github.com/snyk/cli/pull/5409 is merged.
- name: Replace undefined or null security-severity for license-related findings
  # A single in-place pass rewrites both placeholder values ("undefined" and
  # "null") to "0"; every other security-severity value is left untouched.
  run: |
    sudo sed -i -E 's/"security-severity": "(undefined|null)"/"security-severity": "0"/g' "./application/${{ matrix.image.name }}/docker.snyk.sarif"
- name: Upload sarif file to Github Code Scanning
if: always()
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# CustomResourceDefinition for the cluster-scoped LocalModelNode kind
# (serving.kserve.io/v1alpha1). In this variant spec and status are declared
# as opaque objects that preserve unknown fields rather than a detailed schema.
# The annotation below records the controller-gen version (v0.16.2).
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
  annotations:
    controller-gen.kubebuilder.io/version: v0.16.2
  name: localmodelnodes.serving.kserve.io
spec:
  group: serving.kserve.io
  names:
    kind: LocalModelNode
    listKind: LocalModelNodeList
    plural: localmodelnodes
    singular: localmodelnode
  scope: Cluster
  versions:
  - name: v1alpha1
    schema:
      openAPIV3Schema:
        properties:
          apiVersion:
            type: string
          kind:
            type: string
          metadata:
            type: object
          spec:
            type: object
            x-kubernetes-map-type: atomic
            x-kubernetes-preserve-unknown-fields: true
          status:
            type: object
            x-kubernetes-map-type: atomic
            x-kubernetes-preserve-unknown-fields: true
        type: object
    served: true
    storage: true
    # Enables the /status subresource for this version.
    subresources:
      status: {}
61 changes: 61 additions & 0 deletions config/crd/full/serving.kserve.io_localmodelnodes.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
---
# CustomResourceDefinition for the cluster-scoped LocalModelNode kind
# (serving.kserve.io/v1alpha1), with the complete spec/status schema.
# The annotation below records the controller-gen version (v0.16.2).
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
  annotations:
    controller-gen.kubebuilder.io/version: v0.16.2
  name: localmodelnodes.serving.kserve.io
spec:
  group: serving.kserve.io
  names:
    kind: LocalModelNode
    listKind: LocalModelNodeList
    plural: localmodelnodes
    singular: localmodelnode
  scope: Cluster
  versions:
  - name: v1alpha1
    schema:
      openAPIV3Schema:
        properties:
          apiVersion:
            type: string
          kind:
            type: string
          metadata:
            type: object
          spec:
            properties:
              # Required list; every entry must carry both modelName and
              # sourceModelUri.
              localModels:
                items:
                  properties:
                    modelName:
                      type: string
                    sourceModelUri:
                      type: string
                  required:
                  - modelName
                  - sourceModelUri
                  type: object
                type: array
            required:
            - localModels
            type: object
          status:
            properties:
              # String-keyed map whose values are restricted to the enum
              # below (the empty string is an accepted value).
              modelStatus:
                additionalProperties:
                  enum:
                  - ""
                  - ModelDownloadPending
                  - ModelDownloading
                  - ModelDownloaded
                  - ModelDownloadError
                  type: string
                type: object
            type: object
        type: object
    served: true
    storage: true
    # Enables the /status subresource for this version.
    subresources:
      status: {}
38 changes: 38 additions & 0 deletions config/crd/minimal/serving.kserve.io_localmodelnodes.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# CustomResourceDefinition for the cluster-scoped LocalModelNode kind
# (serving.kserve.io/v1alpha1). Minimal variant: spec and status are declared
# as opaque objects that preserve unknown fields rather than a detailed schema.
# The annotation below records the controller-gen version (v0.16.2).
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
  annotations:
    controller-gen.kubebuilder.io/version: v0.16.2
  name: localmodelnodes.serving.kserve.io
spec:
  group: serving.kserve.io
  names:
    kind: LocalModelNode
    listKind: LocalModelNodeList
    plural: localmodelnodes
    singular: localmodelnode
  scope: Cluster
  versions:
  - name: v1alpha1
    schema:
      openAPIV3Schema:
        properties:
          apiVersion:
            type: string
          kind:
            type: string
          metadata:
            type: object
          spec:
            type: object
            x-kubernetes-map-type: atomic
            x-kubernetes-preserve-unknown-fields: true
          status:
            type: object
            x-kubernetes-map-type: atomic
            x-kubernetes-preserve-unknown-fields: true
        type: object
    served: true
    storage: true
    # Enables the /status subresource for this version.
    subresources:
      status: {}
8 changes: 4 additions & 4 deletions hack/quick_install.sh
Original file line number Diff line number Diff line change
@@ -17,11 +17,11 @@ Help() {
echo
}

export ISTIO_VERSION=1.20.4
export KNATIVE_OPERATOR_VERSION=v1.14.5
export KNATIVE_SERVING_VERSION=1.13.1
export ISTIO_VERSION=1.23.2
export KNATIVE_OPERATOR_VERSION=v1.15.7
export KNATIVE_SERVING_VERSION=1.15.2
export KSERVE_VERSION=v0.14.0
export CERT_MANAGER_VERSION=v1.15.1
export CERT_MANAGER_VERSION=v1.16.1
SCRIPT_DIR="$(dirname -- "${BASH_SOURCE[0]}")"
export SCRIPT_DIR

1 change: 1 addition & 0 deletions hack/violation_exceptions.list
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
API rule violation: list_type_missing,github.com/kserve/kserve/pkg/apis/serving/v1alpha1,BuiltInAdapter,Env
API rule violation: list_type_missing,github.com/kserve/kserve/pkg/apis/serving/v1alpha1,InferenceGraphList,Items
API rule violation: list_type_missing,github.com/kserve/kserve/pkg/apis/serving/v1alpha1,InferenceRouter,Steps
API rule violation: list_type_missing,github.com/kserve/kserve/pkg/apis/serving/v1alpha1,LocalModelNodeSpec,LocalModels
API rule violation: list_type_missing,github.com/kserve/kserve/pkg/apis/serving/v1alpha1,ServingRuntimePodSpec,Containers
API rule violation: list_type_missing,github.com/kserve/kserve/pkg/apis/serving/v1alpha1,ServingRuntimePodSpec,ImagePullSecrets
API rule violation: list_type_missing,github.com/kserve/kserve/pkg/apis/serving/v1alpha1,ServingRuntimePodSpec,Tolerations
34 changes: 34 additions & 0 deletions pkg/apis/serving/v1alpha1/local_model_node_status.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
Copyright 2024 The KServe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package v1alpha1

// LocalModelNodeStatus holds the per-model status map reported on a
// LocalModelNode.
type LocalModelNodeStatus struct {
// Status of each local model
ModelStatus map[string]ModelStatus `json:"modelStatus,omitempty"`
}

// ModelStatus describes the download state of a single local model.
// The validation marker below also admits the empty string, for which no
// constant is declared in this block.
// +kubebuilder:validation:Enum="";ModelDownloadPending;ModelDownloading;ModelDownloaded;ModelDownloadError
type ModelStatus string

// ModelStatus Enum values.
// Each constant's string value equals its identifier and matches one entry of
// the Enum validation marker on ModelStatus.
const (
ModelDownloadPending ModelStatus = "ModelDownloadPending"
ModelDownloading ModelStatus = "ModelDownloading"
ModelDownloaded ModelStatus = "ModelDownloaded"
ModelDownloadError ModelStatus = "ModelDownloadError"
)
59 changes: 59 additions & 0 deletions pkg/apis/serving/v1alpha1/local_model_node_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
Copyright 2024 The KServe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package v1alpha1

import metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

// LocalModelInfo pairs a model's source URI with the name used for its
// subdirectory on the local file system. Both fields are tagged as required.
type LocalModelInfo struct {
// Original StorageUri
SourceModelUri string `json:"sourceModelUri" validate:"required"`
// Model name. Used as the subdirectory name to store this model on local file system
ModelName string `json:"modelName" validate:"required"`
}

// LocalModelNodeSpec is the spec of a LocalModelNode: a required list of
// LocalModelInfo entries.
// +k8s:openapi-gen=true
type LocalModelNodeSpec struct {
// List of model source URI and their names
LocalModels []LocalModelInfo `json:"localModels" validate:"required"`
}

// LocalModelNode is a cluster-scoped root API object with a status
// subresource. It carries a LocalModelNodeSpec and a LocalModelNodeStatus.
// +k8s:openapi-gen=true
// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
// +genclient
// +kubebuilder:object:root=true
// +kubebuilder:subresource:status
// +kubebuilder:resource:scope="Cluster"
type LocalModelNode struct {
metav1.TypeMeta `json:",inline"`
metav1.ObjectMeta `json:"metadata,omitempty"`

Spec LocalModelNodeSpec `json:"spec,omitempty"`
Status LocalModelNodeStatus `json:"status,omitempty"`
}

// LocalModelNodeList is the list type for LocalModelNode objects.
// +k8s:openapi-gen=true
// +kubebuilder:object:root=true
// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
type LocalModelNodeList struct {
metav1.TypeMeta `json:",inline"`
metav1.ListMeta `json:"metadata,omitempty"`
Items []LocalModelNode `json:"items" validate:"required"`
}

// init registers LocalModelNode and LocalModelNodeList with the package's
// SchemeBuilder when the package is loaded.
func init() {
SchemeBuilder.Register(&LocalModelNode{}, &LocalModelNodeList{})
}
116 changes: 116 additions & 0 deletions pkg/apis/serving/v1alpha1/zz_generated.deepcopy.go
2 changes: 2 additions & 0 deletions pkg/client/informers/externalversions/generic.go
8 changes: 8 additions & 0 deletions pkg/client/listers/serving/v1alpha1/expansion_generated.go
99 changes: 99 additions & 0 deletions pkg/client/listers/serving/v1alpha1/localmodelnode.go
15 changes: 11 additions & 4 deletions pkg/credentials/hf/hf_secret.go
Original file line number Diff line number Diff line change
@@ -22,16 +22,23 @@ import (

const (
// HFTokenKey is both the secret data key read by BuildSecretEnvs and the
// name of the environment variable it emits for the token.
HFTokenKey = "HF_TOKEN"
// HFTransfer is the name of the environment variable that BuildSecretEnvs
// sets to "1" when a token is present.
HFTransfer = "HF_HUB_ENABLE_HF_TRANSFER"
)

// BuildSecretEnvs converts a Hugging Face credential secret into container
// environment variables.
//
// When the secret's data contains HFTokenKey, exactly two variables are
// returned: the token itself under HFTokenKey (as a plain-text value) and
// HFTransfer set to "1". When the key is absent, or secret is nil, the result
// is an empty, non-nil slice.
func BuildSecretEnvs(secret *v1.Secret) []v1.EnvVar {
	envs := make([]v1.EnvVar, 0)
	// Guard against a nil secret so callers get "no envs" instead of a panic.
	if secret == nil {
		return envs
	}

	if token, ok := secret.Data[HFTokenKey]; ok {
		// Emit the token once, together with the transfer toggle; a second
		// standalone append of the token would duplicate HF_TOKEN.
		envs = append(envs,
			v1.EnvVar{Name: HFTokenKey, Value: string(token)},
			v1.EnvVar{Name: HFTransfer, Value: "1"},
		)
	}
	return envs
}
126 changes: 126 additions & 0 deletions pkg/openapi/openapi_generated.go
68 changes: 68 additions & 0 deletions pkg/openapi/swagger.json
Original file line number Diff line number Diff line change
@@ -435,6 +435,31 @@
}
}
},
"v1alpha1.LocalModelNode": {
"type": "object",
"properties": {
"apiVersion": {
"description": "APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources",
"type": "string"
},
"kind": {
"description": "Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds",
"type": "string"
},
"metadata": {
"default": {},
"$ref": "#/definitions/v1.ObjectMeta"
},
"spec": {
"default": {},
"$ref": "#/definitions/v1alpha1.LocalModelNodeSpec"
},
"status": {
"default": {},
"$ref": "#/definitions/v1alpha1.LocalModelNodeStatus"
}
}
},
"v1alpha1.LocalModelNodeGroup": {
"type": "object",
"properties": {
@@ -512,6 +537,49 @@
}
}
},
"v1alpha1.LocalModelNodeList": {
"type": "object",
"required": [
"items"
],
"properties": {
"apiVersion": {
"description": "APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources",
"type": "string"
},
"items": {
"type": "array",
"items": {
"default": {},
"$ref": "#/definitions/v1alpha1.LocalModelNode"
}
},
"kind": {
"description": "Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds",
"type": "string"
},
"metadata": {
"default": {},
"$ref": "#/definitions/v1.ListMeta"
}
}
},
"v1alpha1.LocalModelNodeSpec": {
"type": "object",
"required": [
"localModels"
],
"properties": {
"localModels": {
"description": "List of model source URI and their names",
"type": "array",
"items": {
"default": {},
"$ref": "#/definitions/v1alpha1.LocalModelInfo"
}
}
}
},
"v1alpha1.ModelSpec": {
"description": "ModelSpec describes a TrainedModel",
"type": "object",
2 changes: 1 addition & 1 deletion python/huggingface_server.Dockerfile
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ ARG POETRY_HOME=/opt/poetry
ARG POETRY_VERSION=1.8.3

# Install vllm
ARG VLLM_VERSION=0.6.1.post2
ARG VLLM_VERSION=0.6.3

RUN apt-get update -y && apt-get install gcc python3.10-venv python3-dev -y && apt-get clean && \
rm -rf /var/lib/apt/lists/*
Original file line number Diff line number Diff line change
@@ -384,7 +384,9 @@ def build_generation_config(
return GenerationConfig(**kwargs)

def apply_chat_template(
self, messages: Iterable[ChatCompletionRequestMessage]
self,
messages: Iterable[ChatCompletionRequestMessage],
chat_template: Optional[str] = None,
) -> ChatPrompt:
"""
Given a list of chat completion messages, convert them to a prompt.
@@ -394,6 +396,7 @@ def apply_chat_template(
str,
self._tokenizer.apply_chat_template(
[m.model_dump() for m in messages],
chat_template=chat_template,
tokenize=False,
add_generation_prompt=True,
),
Original file line number Diff line number Diff line change
@@ -359,9 +359,13 @@ def request_output_to_completion_response(
def apply_chat_template(
self,
messages: Iterable[ChatCompletionRequestMessage,],
chat_template: Optional[str] = None,
):
return self.tokenizer.apply_chat_template(
conversation=messages, tokenize=False, add_generation_prompt=True
conversation=messages,
chat_template=chat_template,
tokenize=False,
add_generation_prompt=True,
)

async def _post_init(self):
Original file line number Diff line number Diff line change
@@ -69,12 +69,15 @@ async def healthy(self) -> bool:
def apply_chat_template(
self,
messages: Iterable[ChatCompletionRequestMessage,],
chat_template: Optional[str] = None,
) -> ChatPrompt:
"""
Given a list of chat completion messages, convert them to a prompt.
"""
return ChatPrompt(
prompt=self.openai_serving_completion.apply_chat_template(messages)
prompt=self.openai_serving_completion.apply_chat_template(
messages, chat_template
)
)

async def create_completion(
2,053 changes: 1,088 additions & 965 deletions python/huggingfaceserver/poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions python/huggingfaceserver/pyproject.toml
Original file line number Diff line number Diff line change
@@ -12,11 +12,11 @@ packages = [
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
kserve = { path = "../kserve", extras = ["storage"], develop = true }
transformers = "~4.43.3"
transformers = ">=4.45.0"
accelerate = "~0.32.0"
torch = "~2.4.0"
vllm = { version = "^0.6.1.post2", optional = true }
setuptools = {version = "^70.0.0", python = "3.12"} # setuptools is not part of python 3.12
vllm = { version = "^0.6.3", optional = true }
setuptools = {version = ">=70.0.0", python = "3.12"} # setuptools is not part of python 3.12

[tool.poetry.extras]
vllm = [
10 changes: 8 additions & 2 deletions python/huggingfaceserver/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -389,6 +389,9 @@ async def test_bloom_chat_completion(bloom_model: HuggingfaceGenerativeModel):
messages=messages,
stream=False,
max_tokens=20,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(params=params, context={})
response = await bloom_model.create_chat_completion(request)
@@ -416,6 +419,9 @@ async def test_bloom_chat_completion_streaming(bloom_model: HuggingfaceGenerativ
messages=messages,
stream=True,
max_tokens=20,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(params=params, context={})
response = await bloom_model.create_chat_completion(request)
@@ -498,6 +504,6 @@ async def test_input_padding_with_pad_token_not_specified(
response = await openai_gpt_model.create_completion(request)
assert (
response.choices[0].text
== "west, and the sun sets in the west. \n the sun rises in the"
== "west , and the sun sets in the west . \n the sun rises in the"
)
assert "a member of the royal family." in response.choices[1].text
assert "a member of the royal family ." in response.choices[1].text
75 changes: 74 additions & 1 deletion python/huggingfaceserver/tests/test_vllm_model.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,11 @@
from huggingfaceserver.vllm.vllm_completions import OpenAIServingCompletion
from huggingfaceserver.vllm.vllm_model import VLLMModel
from kserve.logging import logger
from kserve.protocol.rest.openai import ChatCompletionRequest, CompletionRequest
from kserve.protocol.rest.openai import (
ChatCompletionRequest,
CompletionRequest,
ChatPrompt,
)
from kserve.protocol.rest.openai.errors import OpenAIError
from kserve.protocol.rest.openai.types import (
CreateChatCompletionRequest,
@@ -98,6 +102,54 @@ def mock_load(self) -> bool:
mp.undo()


def compare_chatprompt_to_expected(actual, expected, fields_to_compare=None) -> bool:
    """Check that two chat prompts agree on the selected attributes.

    Args:
        actual: Object produced by the code under test.
        expected: Object holding the reference values.
        fields_to_compare: Attribute names to compare. When ``None``, defaults
            to ``["response_role", "prompt"]``.

    Returns:
        ``True`` if every selected attribute is equal on both objects;
        ``False`` at the first mismatch, which is logged via ``logger.error``.
    """
    if fields_to_compare is None:
        fields_to_compare = [
            "response_role",
            "prompt",
        ]
    for field in fields_to_compare:
        # Read each side once so the logged values are the ones compared.
        actual_value = getattr(actual, field)
        expected_value = getattr(expected, field)
        if actual_value != expected_value:
            logger.error(
                "expected: %s\n got: %s",
                expected_value,
                actual_value,
            )
            return False
    return True


@pytest.mark.asyncio()
class TestChatTemplate:
    """Chat-template rendering checks against the vLLM model fixture."""

    async def test_vllm_chat_completion_tokenization_facebook_opt_model(
        self, vllm_opt_model
    ):
        """A caller-supplied chat template must determine the rendered prompt."""
        # The fixture is unpacked as a pair; only its first element is used.
        model, _ = vllm_opt_model

        system_turn = {
            "role": "system",
            "content": "You are a friendly chatbot who always responds in the style of a pirate",
        }
        user_turn = {
            "role": "user",
            "content": "How many helicopters can a human eat in one sitting?",
        }
        # Template emits each message's content followed by the EOS token.
        template = (
            "{% for message in messages %}"
            "{{ message.content }}{{ eos_token }}"
            "{% endfor %}"
        )

        rendered = model.apply_chat_template([system_turn, user_turn], template)

        want = ChatPrompt(
            response_role="assistant",
            prompt="You are a friendly chatbot who always responds in the style of a pirate</s>How many helicopters can a human eat in one sitting?</s>",
        )
        assert compare_chatprompt_to_expected(rendered, want) is True


def compare_response_to_expected(actual, expected, fields_to_compare=None) -> bool:
if fields_to_compare is None:
fields_to_compare = [
@@ -160,6 +212,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
messages=messages,
stream=False,
max_tokens=10,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(params=params, context={})
response = await opt_model.create_chat_completion(request)
@@ -216,6 +271,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
messages=messages,
stream=False,
max_tokens=10,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
@@ -275,6 +333,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
messages=messages,
stream=True,
max_tokens=10,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
@@ -325,6 +386,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
max_tokens=10,
log_probs=True,
top_logprobs=2,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
@@ -635,6 +699,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
max_tokens=10,
log_probs=True,
top_logprobs=2,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
@@ -890,6 +957,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
messages=messages,
stream=True,
max_tokens=2048,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
@@ -926,6 +996,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
stream=False,
max_tokens=10,
logit_bias={"1527": 50, "27449": 100},
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
14 changes: 14 additions & 0 deletions python/kserve/docs/V1alpha1LocalModelNode.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# V1alpha1LocalModelNode

## Properties
Name | Type | Description | Notes
------------ | ------------- | ------------- | -------------
**api_version** | **str** | APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources | [optional]
**kind** | **str** | Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds | [optional]
**metadata** | [**V1ObjectMeta**](https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ObjectMeta.md) | | [optional]
**spec** | [**V1alpha1LocalModelNodeSpec**](V1alpha1LocalModelNodeSpec.md) | | [optional]
**status** | [**V1alpha1LocalModelNodeStatus**](V1alpha1LocalModelNodeStatus.md) | | [optional]

[[Back to Model list]](../README.md#documentation-for-models) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to README]](../README.md)


13 changes: 13 additions & 0 deletions python/kserve/docs/V1alpha1LocalModelNodeList.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# V1alpha1LocalModelNodeList

## Properties
Name | Type | Description | Notes
------------ | ------------- | ------------- | -------------
**api_version** | **str** | APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources | [optional]
**items** | [**list[V1alpha1LocalModelNode]**](V1alpha1LocalModelNode.md) | |
**kind** | **str** | Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds | [optional]
**metadata** | [**V1ListMeta**](https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ListMeta.md) | | [optional]

[[Back to Model list]](../README.md#documentation-for-models) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to README]](../README.md)


10 changes: 10 additions & 0 deletions python/kserve/docs/V1alpha1LocalModelNodeSpec.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# V1alpha1LocalModelNodeSpec

## Properties
Name | Type | Description | Notes
------------ | ------------- | ------------- | -------------
**local_models** | [**list[V1alpha1LocalModelInfo]**](V1alpha1LocalModelInfo.md) | List of model source URI and their names |

[[Back to Model list]](../README.md#documentation-for-models) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to README]](../README.md)


3 changes: 3 additions & 0 deletions python/kserve/kserve/models/__init__.py
Original file line number Diff line number Diff line change
@@ -43,9 +43,12 @@
from kserve.models.v1alpha1_inference_router import V1alpha1InferenceRouter
from kserve.models.v1alpha1_inference_step import V1alpha1InferenceStep
from kserve.models.v1alpha1_inference_target import V1alpha1InferenceTarget
from kserve.models.v1alpha1_local_model_node import V1alpha1LocalModelNode
from kserve.models.v1alpha1_local_model_node_group import V1alpha1LocalModelNodeGroup
from kserve.models.v1alpha1_local_model_node_group_list import V1alpha1LocalModelNodeGroupList
from kserve.models.v1alpha1_local_model_node_group_spec import V1alpha1LocalModelNodeGroupSpec
from kserve.models.v1alpha1_local_model_node_list import V1alpha1LocalModelNodeList
from kserve.models.v1alpha1_local_model_node_spec import V1alpha1LocalModelNodeSpec
from kserve.models.v1alpha1_model_spec import V1alpha1ModelSpec
from kserve.models.v1alpha1_serving_runtime import V1alpha1ServingRuntime
from kserve.models.v1alpha1_serving_runtime_list import V1alpha1ServingRuntimeList
242 changes: 242 additions & 0 deletions python/kserve/kserve/models/v1alpha1_local_model_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# Copyright 2023 The KServe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding: utf-8

"""
KServe
Python SDK for KServe # noqa: E501
The version of the OpenAPI document: v0.1
Generated by: https://openapi-generator.tech
"""


import pprint
import re # noqa: F401

import six

from kserve.configuration import Configuration


class V1alpha1LocalModelNode(object):
"""NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
Do not edit the class manually.
"""

"""
Attributes:
openapi_types (dict): The key is attribute name
and the value is attribute type.
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
openapi_types = {
'api_version': 'str',
'kind': 'str',
'metadata': 'V1ObjectMeta',
'spec': 'V1alpha1LocalModelNodeSpec',
'status': 'V1alpha1LocalModelNodeStatus'
}

attribute_map = {
'api_version': 'apiVersion',
'kind': 'kind',
'metadata': 'metadata',
'spec': 'spec',
'status': 'status'
}

def __init__(self, api_version=None, kind=None, metadata=None, spec=None, status=None, local_vars_configuration=None): # noqa: E501
"""V1alpha1LocalModelNode - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
self.local_vars_configuration = local_vars_configuration

self._api_version = None
self._kind = None
self._metadata = None
self._spec = None
self._status = None
self.discriminator = None

if api_version is not None:
self.api_version = api_version
if kind is not None:
self.kind = kind
if metadata is not None:
self.metadata = metadata
if spec is not None:
self.spec = spec
if status is not None:
self.status = status

@property
def api_version(self):
"""Gets the api_version of this V1alpha1LocalModelNode. # noqa: E501
APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources # noqa: E501
:return: The api_version of this V1alpha1LocalModelNode. # noqa: E501
:rtype: str
"""
return self._api_version

@api_version.setter
def api_version(self, api_version):
"""Sets the api_version of this V1alpha1LocalModelNode.
APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources # noqa: E501
:param api_version: The api_version of this V1alpha1LocalModelNode. # noqa: E501
:type: str
"""

self._api_version = api_version

@property
def kind(self):
"""Gets the kind of this V1alpha1LocalModelNode. # noqa: E501
Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds # noqa: E501
:return: The kind of this V1alpha1LocalModelNode. # noqa: E501
:rtype: str
"""
return self._kind

@kind.setter
def kind(self, kind):
"""Sets the kind of this V1alpha1LocalModelNode.
Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds # noqa: E501
:param kind: The kind of this V1alpha1LocalModelNode. # noqa: E501
:type: str
"""

self._kind = kind

@property
def metadata(self):
"""Gets the metadata of this V1alpha1LocalModelNode. # noqa: E501
:return: The metadata of this V1alpha1LocalModelNode. # noqa: E501
:rtype: V1ObjectMeta
"""
return self._metadata

@metadata.setter
def metadata(self, metadata):
"""Sets the metadata of this V1alpha1LocalModelNode.
:param metadata: The metadata of this V1alpha1LocalModelNode. # noqa: E501
:type: V1ObjectMeta
"""

self._metadata = metadata

@property
def spec(self):
"""Gets the spec of this V1alpha1LocalModelNode. # noqa: E501
:return: The spec of this V1alpha1LocalModelNode. # noqa: E501
:rtype: V1alpha1LocalModelNodeSpec
"""
return self._spec

@spec.setter
def spec(self, spec):
"""Sets the spec of this V1alpha1LocalModelNode.
:param spec: The spec of this V1alpha1LocalModelNode. # noqa: E501
:type: V1alpha1LocalModelNodeSpec
"""

self._spec = spec

@property
def status(self):
"""Gets the status of this V1alpha1LocalModelNode. # noqa: E501
:return: The status of this V1alpha1LocalModelNode. # noqa: E501
:rtype: V1alpha1LocalModelNodeStatus
"""
return self._status

@status.setter
def status(self, status):
"""Sets the status of this V1alpha1LocalModelNode.
:param status: The status of this V1alpha1LocalModelNode. # noqa: E501
:type: V1alpha1LocalModelNodeStatus
"""

self._status = status

def to_dict(self):
"""Returns the model properties as a dict"""
result = {}

for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
result[attr] = list(map(
lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
value
))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
result[attr] = dict(map(
lambda item: (item[0], item[1].to_dict())
if hasattr(item[1], "to_dict") else item,
value.items()
))
else:
result[attr] = value

return result

def to_str(self):
"""Returns the string representation of the model"""
return pprint.pformat(self.to_dict())

def __repr__(self):
"""For `print` and `pprint`"""
return self.to_str()

def __eq__(self, other):
"""Returns true if both objects are equal"""
if not isinstance(other, V1alpha1LocalModelNode):
return False

return self.to_dict() == other.to_dict()

def __ne__(self, other):
"""Returns true if both objects are not equal"""
if not isinstance(other, V1alpha1LocalModelNode):
return True

return self.to_dict() != other.to_dict()
217 changes: 217 additions & 0 deletions python/kserve/kserve/models/v1alpha1_local_model_node_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# Copyright 2023 The KServe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding: utf-8

"""
KServe
Python SDK for KServe # noqa: E501
The version of the OpenAPI document: v0.1
Generated by: https://openapi-generator.tech
"""


import pprint
import re # noqa: F401

import six

from kserve.configuration import Configuration


class V1alpha1LocalModelNodeList(object):
    """List wrapper model holding V1alpha1LocalModelNode items.

    NOTE: This class is auto generated by OpenAPI Generator
    (https://openapi-generator.tech). Do not edit the class manually.

    Class attributes:
      openapi_types (dict): attribute name -> attribute type.
      attribute_map (dict): attribute name -> JSON key in the definition.
    """

    openapi_types = {
        'api_version': 'str',
        'items': 'list[V1alpha1LocalModelNode]',
        'kind': 'str',
        'metadata': 'V1ListMeta'
    }

    attribute_map = {
        'api_version': 'apiVersion',
        'items': 'items',
        'kind': 'kind',
        'metadata': 'metadata'
    }

    def __init__(self, api_version=None, items=None, kind=None, metadata=None, local_vars_configuration=None):  # noqa: E501
        """V1alpha1LocalModelNodeList - a model defined in OpenAPI.

        `items` is always routed through its setter (so it is validated
        even when omitted); the other fields are only assigned when a
        non-None value is supplied.
        """
        self.local_vars_configuration = (
            Configuration()
            if local_vars_configuration is None
            else local_vars_configuration
        )

        self._api_version = None
        self._items = None
        self._kind = None
        self._metadata = None
        self.discriminator = None

        if api_version is not None:
            self.api_version = api_version
        # Required field: assigned unconditionally so the setter can reject None.
        self.items = items
        if kind is not None:
            self.kind = kind
        if metadata is not None:
            self.metadata = metadata

    @property
    def api_version(self):
        """Return the api_version of this V1alpha1LocalModelNodeList.

        APIVersion defines the versioned schema of this representation of
        an object. Servers should convert recognized schemas to the latest
        internal value, and may reject unrecognized values. More info:
        https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources  # noqa: E501

        :return: The api_version of this V1alpha1LocalModelNodeList.
        :rtype: str
        """
        return self._api_version

    @api_version.setter
    def api_version(self, api_version):
        """Store the api_version of this V1alpha1LocalModelNodeList (no validation).

        :param api_version: The api_version of this V1alpha1LocalModelNodeList.
        :type: str
        """
        self._api_version = api_version

    @property
    def items(self):
        """Return the items of this V1alpha1LocalModelNodeList.

        :return: The items of this V1alpha1LocalModelNodeList.
        :rtype: list[V1alpha1LocalModelNode]
        """
        return self._items

    @items.setter
    def items(self, items):
        """Store the items of this V1alpha1LocalModelNodeList.

        :param items: The items of this V1alpha1LocalModelNodeList.
        :type: list[V1alpha1LocalModelNode]
        :raises ValueError: if `items` is None while client-side
            validation is enabled on the configuration.
        """
        validating = self.local_vars_configuration.client_side_validation
        if validating and items is None:
            raise ValueError("Invalid value for `items`, must not be `None`")  # noqa: E501

        self._items = items

    @property
    def kind(self):
        """Return the kind of this V1alpha1LocalModelNodeList.

        Kind is a string value representing the REST resource this object
        represents. Servers may infer this from the endpoint the client
        submits requests to. Cannot be updated. In CamelCase. More info:
        https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds  # noqa: E501

        :return: The kind of this V1alpha1LocalModelNodeList.
        :rtype: str
        """
        return self._kind

    @kind.setter
    def kind(self, kind):
        """Store the kind of this V1alpha1LocalModelNodeList (no validation).

        :param kind: The kind of this V1alpha1LocalModelNodeList.
        :type: str
        """
        self._kind = kind

    @property
    def metadata(self):
        """Return the metadata of this V1alpha1LocalModelNodeList.

        :return: The metadata of this V1alpha1LocalModelNodeList.
        :rtype: V1ListMeta
        """
        return self._metadata

    @metadata.setter
    def metadata(self, metadata):
        """Store the metadata of this V1alpha1LocalModelNodeList (no validation).

        :param metadata: The metadata of this V1alpha1LocalModelNodeList.
        :type: V1ListMeta
        """
        self._metadata = metadata

    def to_dict(self):
        """Build a plain dict keyed by the attribute names in `openapi_types`.

        Values exposing `to_dict` are converted through it; the same
        conversion is applied one level deep to list elements and dict
        values.
        """
        def _plain(obj):
            # Convert nested models; leave every other object untouched.
            return obj.to_dict() if hasattr(obj, "to_dict") else obj

        result = {}
        for attr, _ in six.iteritems(self.openapi_types):
            value = getattr(self, attr)
            if isinstance(value, list):
                result[attr] = [_plain(element) for element in value]
            elif hasattr(value, "to_dict"):
                result[attr] = value.to_dict()
            elif isinstance(value, dict):
                result[attr] = {key: _plain(val) for key, val in value.items()}
            else:
                result[attr] = value
        return result

    def to_str(self):
        """Return the pretty-printed form of `to_dict()` as a string."""
        as_dict = self.to_dict()
        return pprint.pformat(as_dict)

    def __repr__(self):
        """Delegate to `to_str()` so `print` and `pprint` show the model contents."""
        return self.to_str()

    def __eq__(self, other):
        """Return True when `other` is a V1alpha1LocalModelNodeList with an equal `to_dict()`."""
        return (
            isinstance(other, V1alpha1LocalModelNodeList)
            and self.to_dict() == other.to_dict()
        )

    def __ne__(self, other):
        """Return True when `other` is not a V1alpha1LocalModelNodeList or its `to_dict()` differs."""
        if isinstance(other, V1alpha1LocalModelNodeList):
            return self.to_dict() != other.to_dict()
        return True
137 changes: 137 additions & 0 deletions python/kserve/kserve/models/v1alpha1_local_model_node_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2023 The KServe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding: utf-8

"""
KServe
Python SDK for KServe # noqa: E501
The version of the OpenAPI document: v0.1
Generated by: https://openapi-generator.tech
"""


import pprint
import re # noqa: F401

import six

from kserve.configuration import Configuration


class V1alpha1LocalModelNodeSpec(object):
    """Spec model carrying the `local_models` list of a local model node.

    NOTE: This class is auto generated by OpenAPI Generator
    (https://openapi-generator.tech). Do not edit the class manually.

    Class attributes:
      openapi_types (dict): attribute name -> attribute type.
      attribute_map (dict): attribute name -> JSON key in the definition.
    """

    openapi_types = {
        'local_models': 'list[V1alpha1LocalModelInfo]'
    }

    attribute_map = {
        'local_models': 'localModels'
    }

    def __init__(self, local_models=None, local_vars_configuration=None):  # noqa: E501
        """V1alpha1LocalModelNodeSpec - a model defined in OpenAPI.

        `local_models` is always routed through its setter, so it is
        validated even when omitted.
        """
        self.local_vars_configuration = (
            Configuration()
            if local_vars_configuration is None
            else local_vars_configuration
        )

        self._local_models = None
        self.discriminator = None

        # Required field: assigned unconditionally so the setter can reject None.
        self.local_models = local_models

    @property
    def local_models(self):
        """Return the local_models of this V1alpha1LocalModelNodeSpec.

        List of model source URI and their names.

        :return: The local_models of this V1alpha1LocalModelNodeSpec.
        :rtype: list[V1alpha1LocalModelInfo]
        """
        return self._local_models

    @local_models.setter
    def local_models(self, local_models):
        """Store the local_models of this V1alpha1LocalModelNodeSpec.

        List of model source URI and their names.

        :param local_models: The local_models of this V1alpha1LocalModelNodeSpec.
        :type: list[V1alpha1LocalModelInfo]
        :raises ValueError: if `local_models` is None while client-side
            validation is enabled on the configuration.
        """
        validating = self.local_vars_configuration.client_side_validation
        if validating and local_models is None:
            raise ValueError("Invalid value for `local_models`, must not be `None`")  # noqa: E501

        self._local_models = local_models

    def to_dict(self):
        """Build a plain dict keyed by the attribute names in `openapi_types`.

        Values exposing `to_dict` are converted through it; the same
        conversion is applied one level deep to list elements and dict
        values.
        """
        def _plain(obj):
            # Convert nested models; leave every other object untouched.
            return obj.to_dict() if hasattr(obj, "to_dict") else obj

        result = {}
        for attr, _ in six.iteritems(self.openapi_types):
            value = getattr(self, attr)
            if isinstance(value, list):
                result[attr] = [_plain(element) for element in value]
            elif hasattr(value, "to_dict"):
                result[attr] = value.to_dict()
            elif isinstance(value, dict):
                result[attr] = {key: _plain(val) for key, val in value.items()}
            else:
                result[attr] = value
        return result

    def to_str(self):
        """Return the pretty-printed form of `to_dict()` as a string."""
        as_dict = self.to_dict()
        return pprint.pformat(as_dict)

    def __repr__(self):
        """Delegate to `to_str()` so `print` and `pprint` show the model contents."""
        return self.to_str()

    def __eq__(self, other):
        """Return True when `other` is a V1alpha1LocalModelNodeSpec with an equal `to_dict()`."""
        return (
            isinstance(other, V1alpha1LocalModelNodeSpec)
            and self.to_dict() == other.to_dict()
        )

    def __ne__(self, other):
        """Return True when `other` is not a V1alpha1LocalModelNodeSpec or its `to_dict()` differs."""
        if isinstance(other, V1alpha1LocalModelNodeSpec):
            return self.to_dict() != other.to_dict()
        return True
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
# limitations under the License.

from abc import abstractmethod
from typing import AsyncIterator, Iterable, Union, cast
from typing import AsyncIterator, Iterable, Union, cast, Optional

from kserve.protocol.rest.openai.types import (
ChatCompletion,
@@ -53,7 +53,9 @@ class OpenAIChatAdapterModel(OpenAIModel):

@abstractmethod
def apply_chat_template(
self, messages: Iterable[ChatCompletionRequestMessage]
self,
messages: Iterable[ChatCompletionRequestMessage],
chat_template: Optional[str] = None,
) -> ChatPrompt:
"""
Given a list of chat completion messages, convert them to a prompt.
@@ -193,15 +195,16 @@ def completion_to_chat_completion_chunk(
)

async def create_chat_completion(
self, request: ChatCompletionRequest
self,
request: ChatCompletionRequest,
) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
params = request.params

if params.n != 1:
raise InvalidInput("n != 1 is not supported")

# Convert the messages into a prompt
chat_prompt = self.apply_chat_template(params.messages)
chat_prompt = self.apply_chat_template(params.messages, params.chat_template)
# Translate the chat completion request to a completion request
completion_params = self.chat_completion_params_to_completion_params(
params, chat_prompt.prompt
5 changes: 3 additions & 2 deletions python/kserve/kserve/protocol/rest/openai/types/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Steps to generate


```bash
curl https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml -o openapi-2.0.0.yaml
datamodel-codegen --input openapi-2.0.0.yaml --input-file-type openapi --output openapi.py --output-model-type pydantic_v2.BaseModel --use-double-quotes --collapse-root-models --enum-field-as-literal all --strict-nullable
datamodel-codegen --input openapi-2.0.0.yaml --input-file-type openapi --output openapi.py --output-model-type pydantic_v2.BaseModel --use-double-quotes --collapse-root-models --enum-field-as-literal all --strict-nullable
```
Adapted from the generated `openapi.py`
13 changes: 13 additions & 0 deletions python/kserve/kserve/protocol/rest/openai/types/openapi.py
Original file line number Diff line number Diff line change
@@ -2705,6 +2705,19 @@ class CreateChatCompletionRequest(BaseModel):
max_length=128,
min_length=1,
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)


class RunStepObject(BaseModel):
45 changes: 30 additions & 15 deletions python/kserve/kserve/storage/storage.py
Original file line number Diff line number Diff line change
@@ -319,6 +319,7 @@ def _download_hf(uri, temp_dir: str) -> str:
def _download_gcs(uri, temp_dir: str) -> str:
from google.auth import exceptions
from google.cloud import storage
import copy

try:
storage_client = storage.Client()
@@ -333,22 +334,36 @@ def _download_gcs(uri, temp_dir: str) -> str:
prefix = prefix + "/"
blobs = bucket.list_blobs(prefix=prefix)
file_count = 0
for blob in blobs:
# Replace any prefix from the object key with temp_dir
subdir_object_key = blob.name.replace(bucket_path, "", 1).lstrip("/")

# Create necessary subdirectory to store the object locally
if "/" in subdir_object_key:
local_object_dir = os.path.join(
temp_dir, subdir_object_key.rsplit("/", 1)[0]
)
if not os.path.isdir(local_object_dir):
os.makedirs(local_object_dir, exist_ok=True)
if subdir_object_key.strip() != "" and not subdir_object_key.endswith("/"):
dest_path = os.path.join(temp_dir, subdir_object_key)
logger.info("Downloading: %s", dest_path)
blob.download_to_filename(dest_path)
file_count += 1
# Shallow copy, otherwise Iterator has already started
shallow_blobs = copy.copy(blobs)
blob = bucket.blob(bucket_path)
# checks if the blob is a file or a directory
if blob.name == bucket_path and len(list(shallow_blobs)) == 0:
dest_path = os.path.join(temp_dir, os.path.basename(bucket_path))
logger.info("Downloading single file to: %s", dest_path)
blob.download_to_filename(dest_path)
file_count = 1

else:
for blob in blobs:
# Replace any prefix from the object key with temp_dir
subdir_object_key = blob.name.replace(bucket_path, "", 1).lstrip("/")
# Create necessary subdirectory to store the object locally
if "/" in subdir_object_key:
local_object_dir = os.path.join(
temp_dir, subdir_object_key.rsplit("/", 1)[0]
)
if not os.path.isdir(local_object_dir):
os.makedirs(local_object_dir, exist_ok=True)
if subdir_object_key.strip() != "" and not subdir_object_key.endswith(
"/"
):
dest_path = os.path.join(temp_dir, subdir_object_key)
logger.info("Downloading: %s", dest_path)
blob.download_to_filename(dest_path)
file_count += 1

if file_count == 0:
raise RuntimeError("Failed to fetch model. No model found in %s." % uri)

16 changes: 16 additions & 0 deletions python/kserve/kserve/storage/test/test_gcs_storage.py
Original file line number Diff line number Diff line change
@@ -105,6 +105,22 @@ def test_download_model_from_gcs(mock_client):
assert "/mock.object" in arg_list[0][0]


@mock.patch("google.cloud.storage.Client")
def test_download_model_from_gcs_as_single_file(mock_client):
    """Download a gs:// URI that names one object rather than a prefix.

    The patched storage client hands back a mock bucket whose `blob()`
    returns the mock object, and the test checks that the object's
    `download_to_filename` was called with a path containing
    "/mock.object".
    """
    # Mock object standing in for the blob at bar/mock.object.
    blob = create_mock_dir_with_file("bar", "mock.object")
    blob.exists.return_value = True

    # Wire client -> bucket -> blob so the lookup resolves to the mock above.
    bucket = mock.MagicMock()
    bucket.blob.return_value = blob
    mock_client.return_value.bucket.return_value = bucket

    Storage.download("gs://foo/bar/mock.object")

    download_args = get_call_args(blob.download_to_filename.call_args_list)
    assert "/mock.object" in download_args[0][0]


@mock.patch("os.remove")
@mock.patch("os.mkdir")
@mock.patch("zipfile.ZipFile")
69 changes: 67 additions & 2 deletions python/kserve/poetry.lock
2 changes: 1 addition & 1 deletion python/kserve/pyproject.toml
Original file line number Diff line number Diff line change
@@ -62,7 +62,7 @@ azure-storage-blob = { version = "^12.20.0", optional = true }
azure-storage-file-share = { version = "^12.16.0", optional = true }
azure-identity = { version = "^1.15.0", optional = true }
boto3 = { version = "^1.29.0", optional = true }
huggingface-hub = { version = "^0.24.5", optional = true }
huggingface-hub = { version = "^0.24.5", extras = ["hf-transfer"], optional = true }

# Logging dependencies. They can be opted into by apps.
asgi-logger = { version = "^0.1.0", optional = true }
3 changes: 2 additions & 1 deletion python/kserve/test/test_openai.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@

from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncIterator, Callable, Iterable, List, Tuple, Union, cast
from typing import AsyncIterator, Callable, Iterable, List, Tuple, Union, cast, Optional
from unittest.mock import MagicMock, patch

import httpx
@@ -86,6 +86,7 @@ async def create_completion(
def apply_chat_template(
    self,
    messages: Iterable[ChatCompletionRequestMessage],
    chat_template: Optional[str] = None,
) -> ChatPrompt:
    """Test stub: ignore `messages` and `chat_template` and return a fixed prompt.

    :param messages: chat messages (unused by this stub)
    :param chat_template: optional chat template (unused by this stub)
    :return: a ChatPrompt whose prompt is always "hello"
    """
    return ChatPrompt(prompt="hello")

72 changes: 72 additions & 0 deletions python/kserve/test/test_v1alpha1_local_model_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2023 The KServe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding: utf-8

"""
KServe
Python SDK for KServe # noqa: E501
The version of the OpenAPI document: v0.1
Generated by: https://openapi-generator.tech
"""


from __future__ import absolute_import

import unittest
import datetime

import kserve
from kserve.models.v1alpha1_local_model_node import V1alpha1LocalModelNode # noqa: E501
from kserve.rest import ApiException


class TestV1alpha1LocalModelNode(unittest.TestCase):
    """Unit tests for the generated V1alpha1LocalModelNode model."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def make_instance(self, include_optional):
        """Build a V1alpha1LocalModelNode for the tests.

        :param include_optional: when False only required params are
            passed (this model is built with no arguments); when True the
            optional params are passed as well.
        :return: the constructed V1alpha1LocalModelNode
        """
        if include_optional:
            return V1alpha1LocalModelNode(
                api_version="0",
                kind="0",
                metadata=None,
                spec=kserve.models.v1alpha1_local_model_node_spec.V1alpha1LocalModelNodeSpec(
                    local_models=[None],
                ),
                status=None,
            )
        else:
            return V1alpha1LocalModelNode()

    def testV1alpha1LocalModelNode(self):
        """Construct both variants and verify the values the model exposes."""
        inst_req_only = self.make_instance(include_optional=False)
        inst_req_and_optional = self.make_instance(include_optional=True)

        # With no arguments every field keeps its initial None.
        self.assertIsInstance(inst_req_only, V1alpha1LocalModelNode)
        self.assertIsNone(inst_req_only.api_version)
        self.assertIsNone(inst_req_only.kind)
        self.assertIsNone(inst_req_only.metadata)
        self.assertIsNone(inst_req_only.spec)
        self.assertIsNone(inst_req_only.status)

        # Supplied values must round-trip through the property accessors.
        self.assertIsInstance(inst_req_and_optional, V1alpha1LocalModelNode)
        self.assertEqual(inst_req_and_optional.api_version, "0")
        self.assertEqual(inst_req_and_optional.kind, "0")
        self.assertIsNone(inst_req_and_optional.metadata)
        self.assertIsNone(inst_req_and_optional.status)
        self.assertEqual(inst_req_and_optional.spec.local_models, [None])

        # Equality is value-based: identical inputs compare equal,
        # differing inputs do not.
        self.assertEqual(
            inst_req_and_optional, self.make_instance(include_optional=True)
        )
        self.assertNotEqual(inst_req_only, inst_req_and_optional)


if __name__ == "__main__":
unittest.main()
93 changes: 93 additions & 0 deletions python/kserve/test/test_v1alpha1_local_model_node_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2023 The KServe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding: utf-8

"""
KServe
Python SDK for KServe # noqa: E501
The version of the OpenAPI document: v0.1
Generated by: https://openapi-generator.tech
"""


from __future__ import absolute_import

import unittest
import datetime

import kserve
from kserve.models.v1alpha1_local_model_node_list import (
V1alpha1LocalModelNodeList,
) # noqa: E501
from kserve.rest import ApiException


class TestV1alpha1LocalModelNodeList(unittest.TestCase):
    """Unit tests for the generated V1alpha1LocalModelNodeList model."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def _make_node(self):
        """Build the single V1alpha1LocalModelNode used as the list item."""
        return kserve.models.v1alpha1_local_model_node.V1alpha1LocalModelNode(
            api_version="0",
            kind="0",
            metadata=None,
            spec=kserve.models.v1alpha1_local_model_node_spec.V1alpha1LocalModelNodeSpec(
                local_models=[None],
            ),
            status=None,
        )

    def make_instance(self, include_optional):
        """Build a V1alpha1LocalModelNodeList for the tests.

        :param include_optional: when False only the required `items`
            param is passed; when True the optional params are passed too.
        :return: the constructed V1alpha1LocalModelNodeList
        """
        if include_optional:
            return V1alpha1LocalModelNodeList(
                api_version="0",
                items=[self._make_node()],
                kind="0",
                metadata=None,
            )
        else:
            return V1alpha1LocalModelNodeList(
                items=[self._make_node()],
            )

    def testV1alpha1LocalModelNodeList(self):
        """Construct both variants and verify the values the model exposes."""
        inst_req_only = self.make_instance(include_optional=False)
        inst_req_and_optional = self.make_instance(include_optional=True)

        # Required-only: items are stored, optional fields stay None.
        self.assertIsInstance(inst_req_only, V1alpha1LocalModelNodeList)
        self.assertEqual(len(inst_req_only.items), 1)
        self.assertIsNone(inst_req_only.api_version)
        self.assertIsNone(inst_req_only.kind)
        self.assertIsNone(inst_req_only.metadata)

        # Supplied values must round-trip through the property accessors.
        self.assertIsInstance(inst_req_and_optional, V1alpha1LocalModelNodeList)
        self.assertEqual(inst_req_and_optional.api_version, "0")
        self.assertEqual(inst_req_and_optional.kind, "0")
        self.assertIsNone(inst_req_and_optional.metadata)
        self.assertEqual(len(inst_req_and_optional.items), 1)
        self.assertEqual(
            inst_req_and_optional.items[0].spec.local_models, [None]
        )

        # Equality is value-based: identical inputs compare equal,
        # differing inputs do not.
        self.assertEqual(
            inst_req_and_optional, self.make_instance(include_optional=True)
        )
        self.assertNotEqual(inst_req_only, inst_req_and_optional)


if __name__ == "__main__":
unittest.main()
68 changes: 68 additions & 0 deletions python/kserve/test/test_v1alpha1_local_model_node_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023 The KServe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding: utf-8

"""
KServe
Python SDK for KServe # noqa: E501
The version of the OpenAPI document: v0.1
Generated by: https://openapi-generator.tech
"""


from __future__ import absolute_import

import unittest
import datetime

import kserve
from kserve.models.v1alpha1_local_model_node_spec import (
V1alpha1LocalModelNodeSpec,
) # noqa: E501
from kserve.rest import ApiException


class TestV1alpha1LocalModelNodeSpec(unittest.TestCase):
    """Unit tests for the generated V1alpha1LocalModelNodeSpec model."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def make_instance(self, include_optional):
        """Build a V1alpha1LocalModelNodeSpec for the tests.

        The model's only param, `local_models`, is required, so both
        variants are constructed with the same arguments.

        :param include_optional: when False only required params are
            included, when True both required and optional params are.
        :return: the constructed V1alpha1LocalModelNodeSpec
        """
        if include_optional:
            return V1alpha1LocalModelNodeSpec(local_models=[None])
        else:
            return V1alpha1LocalModelNodeSpec(
                local_models=[None],
            )

    def testV1alpha1LocalModelNodeSpec(self):
        """Construct both variants and verify the values the model exposes."""
        inst_req_only = self.make_instance(include_optional=False)
        inst_req_and_optional = self.make_instance(include_optional=True)

        for inst in (inst_req_only, inst_req_and_optional):
            self.assertIsInstance(inst, V1alpha1LocalModelNodeSpec)
            self.assertEqual(inst.local_models, [None])
            self.assertEqual(inst.to_dict(), {"local_models": [None]})

        # Both variants are built from identical arguments, so they must
        # compare equal.
        self.assertEqual(inst_req_only, inst_req_and_optional)


if __name__ == "__main__":
unittest.main()
Binary file removed python/storage-initializer/scripts/model.joblib
Binary file not shown.
Loading

0 comments on commit cd93dda

Please sign in to comment.