Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for inpainting task in DS-MII #410

Merged
merged 10 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mii/legacy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def query(self, request_dict, **query_kwargs):
elif self.task == TaskType.TEXT2IMG:
args = (request_dict["prompt"], request_dict.get("negative_prompt", None))
kwargs = query_kwargs
elif self.task == TaskType.INPAINTING:
negative_prompt = request_dict.get("negative_prompt", None)
args = (request_dict["prompt"],
request_dict["image"],
request_dict["mask_image"],
negative_prompt)
kwargs = query_kwargs
else:
args = (request_dict["query"], )
kwargs = query_kwargs
Expand Down
6 changes: 6 additions & 0 deletions mii/legacy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class TaskType(str, Enum):
CONVERSATIONAL = "conversational"
TEXT2IMG = "text-to-image"
ZERO_SHOT_IMAGE_CLASSIFICATION = "zero-shot-image-classification"
INPAINTING = "text-to-image-inpainting"


class ModelProvider(str, Enum):
Expand Down Expand Up @@ -60,6 +61,11 @@ class ModelProvider(str, Enum):
TaskType.TEXT2IMG: ["prompt"],
TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION: ["image",
"candidate_labels"],
TaskType.INPAINTING: [
"prompt",
"image",
"mask_image",
]
}

MII_CACHE_PATH = "MII_CACHE_PATH"
Expand Down
3 changes: 3 additions & 0 deletions mii/legacy/grpc_related/modelresponse_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def ConversationalReply(self, request, context):
def ZeroShotImgClassificationReply(self, request, context):
return self._run_inference("ZeroShotImgClassificationReply", request)

def InpaintingReply(self, request, context):
return self._run_inference("InpaintingReply", request)


class AtomicCounter:
def __init__(self, initial_value=0):
Expand Down
9 changes: 9 additions & 0 deletions mii/legacy/grpc_related/proto/legacymodelresponse.proto
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ service ModelResponse {
rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {}
rpc Txt2ImgReply(Text2ImageRequest) returns (ImageReply) {}
rpc ZeroShotImgClassificationReply (ZeroShotImgClassificationRequest) returns (SingleStringReply) {}
rpc InpaintingReply(InpaintingRequest) returns (ImageReply) {}
}

message Value {
Expand Down Expand Up @@ -114,3 +115,11 @@ message ZeroShotImgClassificationRequest {
repeated string candidate_labels = 2;
map<string,Value> query_kwargs = 3;
}

message InpaintingRequest {
repeated string prompt = 1;
repeated bytes image = 2;
repeated bytes mask_image = 3;
repeated string negative_prompt = 4;
map<string,Value> query_kwargs = 5;
}
13 changes: 9 additions & 4 deletions mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# DeepSpeed Team
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: legacymodelresponse.proto
# Protobuf Python Version: 4.25.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
Expand All @@ -17,7 +16,7 @@
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x19legacymodelresponse.proto\x12\x13legacymodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xc7\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12O\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x39.legacymodelresponse.SingleStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xc5\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12N\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x38.legacymodelresponse.MultiStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xc5\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12\x45\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32/.legacymodelresponse.QARequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x94\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12O\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x39.legacymodelresponse.ConversationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\"\xdb\x01\n\x11Text2ImageRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\x17\n\x0fnegative_prompt\x18\x02 \x03(\t\x12M\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x37.legacymodelresponse.Text2ImageRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xf9\x01\n ZeroShotImgClassificationRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x18\n\x10\x63\x61ndidate_labels\x18\x02 \x03(\t\x12\\\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x46.legacymodelresponse.ZeroShotImgClassificationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\x32\xb7\x08\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12I\n\rCreateSession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12J\n\x0e\x44\x65stroySession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x62\n\x0eGeneratorReply\x12\'.legacymodelresponse.MultiStringRequest\x1a%.legacymodelresponse.MultiStringReply\"\x00\x12i\n\x13\x43lassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x62\n\x16QuestionAndAnswerReply\x12\x1e.legacymodelresponse.QARequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x63\n\rFillMaskReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12n\n\x18TokenClassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12i\n\x13\x43onversationalReply\x12(.legacymodelresponse.ConversationRequest\x1a&.legacymodelresponse.ConversationReply\"\x00\x12Y\n\x0cTxt2ImgReply\x12&.legacymodelresponse.Text2ImageRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x12\x81\x01\n\x1eZeroShotImgClassificationReply\x12\x35.legacymodelresponse.ZeroShotImgClassificationRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x62\x06proto3'
b'\n\x19legacymodelresponse.proto\x12\x13legacymodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xc7\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12O\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x39.legacymodelresponse.SingleStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xc5\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12N\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x38.legacymodelresponse.MultiStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xc5\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12\x45\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32/.legacymodelresponse.QARequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x94\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12O\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x39.legacymodelresponse.ConversationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\"\xdb\x01\n\x11Text2ImageRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\x17\n\x0fnegative_prompt\x18\x02 \x03(\t\x12M\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x37.legacymodelresponse.Text2ImageRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xf9\x01\n ZeroShotImgClassificationRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x18\n\x10\x63\x61ndidate_labels\x18\x02 \x03(\t\x12\\\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x46.legacymodelresponse.ZeroShotImgClassificationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xfe\x01\n\x11InpaintingRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\r\n\x05image\x18\x02 \x03(\x0c\x12\x12\n\nmask_image\x18\x03 \x03(\x0c\x12\x17\n\x0fnegative_prompt\x18\x04 \x03(\t\x12M\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x37.legacymodelresponse.InpaintingRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\x32\x95\t\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12I\n\rCreateSession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12J\n\x0e\x44\x65stroySession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x62\n\x0eGeneratorReply\x12\'.legacymodelresponse.MultiStringRequest\x1a%.legacymodelresponse.MultiStringReply\"\x00\x12i\n\x13\x43lassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x62\n\x16QuestionAndAnswerReply\x12\x1e.legacymodelresponse.QARequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x63\n\rFillMaskReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12n\n\x18TokenClassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12i\n\x13\x43onversationalReply\x12(.legacymodelresponse.ConversationRequest\x1a&.legacymodelresponse.ConversationReply\"\x00\x12Y\n\x0cTxt2ImgReply\x12&.legacymodelresponse.Text2ImageRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x12\x81\x01\n\x1eZeroShotImgClassificationReply\x12\x35.legacymodelresponse.ZeroShotImgClassificationRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\\\n\x0fInpaintingReply\x12&.legacymodelresponse.InpaintingRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x62\x06proto3'
)

_globals = globals()
Expand All @@ -38,6 +37,8 @@
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._options = None
_globals[
'_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._options = None
_globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_VALUE']._serialized_start = 79
_globals['_VALUE']._serialized_end = 174
_globals['_SESSIONID']._serialized_start = 176
Expand Down Expand Up @@ -75,6 +76,10 @@
_globals[
'_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
_globals['_MODELRESPONSE']._serialized_start = 2009
_globals['_MODELRESPONSE']._serialized_end = 3088
_globals['_INPAINTINGREQUEST']._serialized_start = 2009
_globals['_INPAINTINGREQUEST']._serialized_end = 2263
_globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
_globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
_globals['_MODELRESPONSE']._serialized_start = 2266
_globals['_MODELRESPONSE']._serialized_end = 3439
# @@protoc_insertion_point(module_scope)
44 changes: 44 additions & 0 deletions mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def __init__(self, channel):
SerializeToString,
response_deserializer=legacymodelresponse__pb2.SingleStringReply.FromString,
)
self.InpaintingReply = channel.unary_unary(
'/legacymodelresponse.ModelResponse/InpaintingReply',
request_serializer=legacymodelresponse__pb2.InpaintingRequest.
SerializeToString,
response_deserializer=legacymodelresponse__pb2.ImageReply.FromString,
)


class ModelResponseServicer(object):
Expand Down Expand Up @@ -151,6 +157,12 @@ def ZeroShotImgClassificationReply(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def InpaintingReply(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_ModelResponseServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -231,6 +243,12 @@ def add_ModelResponseServicer_to_server(servicer, server):
response_serializer=legacymodelresponse__pb2.SingleStringReply.
SerializeToString,
),
'InpaintingReply':
grpc.unary_unary_rpc_method_handler(
servicer.InpaintingReply,
request_deserializer=legacymodelresponse__pb2.InpaintingRequest.FromString,
response_serializer=legacymodelresponse__pb2.ImageReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'legacymodelresponse.ModelResponse',
Expand Down Expand Up @@ -526,3 +544,29 @@ def ZeroShotImgClassificationReply(request,
wait_for_ready,
timeout,
metadata)

@staticmethod
def InpaintingReply(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/legacymodelresponse.ModelResponse/InpaintingReply',
legacymodelresponse__pb2.InpaintingRequest.SerializeToString,
legacymodelresponse__pb2.ImageReply.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata)
48 changes: 47 additions & 1 deletion mii/legacy/method_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mii.legacy.constants import TaskType
from mii.legacy.grpc_related.proto import legacymodelresponse_pb2 as modelresponse_pb2
from mii.legacy.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs
from mii.legacy.models.utils import ImageResponse
from mii.legacy.models.utils import ImageResponse, convert_bytes_to_pil_image


def single_string_request_to_proto(self, request_dict, **query_kwargs):
Expand Down Expand Up @@ -312,6 +312,51 @@ def run_inference(self, inference_pipeline, args, kwargs):
return inference_pipeline(image, candidate_labels=candidate_labels, **kwargs)


class InpaintingMethods(Text2ImgMethods):
@property
def method(self):
return "InpaintingReply"

def run_inference(self, inference_pipeline, args, kwargs):
prompt, image, mask_image, negative_prompt = args
return inference_pipeline(prompt=prompt,
image=image,
mask_image=mask_image,
negative_prompt=negative_prompt,
**kwargs)

def pack_request_to_proto(self, request_dict, **query_kwargs):
prompt = request_dict["prompt"]
prompt = prompt if isinstance(prompt, list) else [prompt]
negative_prompt = request_dict.get("negative_prompt", [""] * len(prompt))
negative_prompt = negative_prompt if isinstance(negative_prompt,
list) else [negative_prompt]
image = request_dict["image"] if isinstance(request_dict["image"],
list) else [request_dict["image"]]
mask_image = request_dict["mask_image"] if isinstance(
request_dict["mask_image"],
list) else [request_dict["mask_image"]]

return modelresponse_pb2.InpaintingRequest(
prompt=prompt,
image=image,
mask_image=mask_image,
negative_prompt=negative_prompt,
query_kwargs=kwarg_dict_to_proto(query_kwargs),
)

def unpack_request_from_proto(self, request):
kwargs = unpack_proto_query_kwargs(request.query_kwargs)

image = [convert_bytes_to_pil_image(img) for img in request.image]
mask_image = [
convert_bytes_to_pil_image(mask_image) for mask_image in request.mask_image
]

args = (list(request.prompt), image, mask_image, list(request.negative_prompt))
return args, kwargs


GRPC_METHOD_TABLE = {
TaskType.TEXT_GENERATION: TextGenerationMethods(),
TaskType.TEXT_CLASSIFICATION: TextClassificationMethods(),
Expand All @@ -321,4 +366,5 @@ def run_inference(self, inference_pipeline, args, kwargs):
TaskType.CONVERSATIONAL: ConversationalMethods(),
TaskType.TEXT2IMG: Text2ImgMethods(),
TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION: ZeroShotImgClassificationMethods(),
TaskType.INPAINTING: InpaintingMethods(),
}
Loading