Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AI] Integrate Image-to-Image tasks for a generic image-to-image pipeline #3337

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type AI interface {
ImageToText(context.Context, worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error)
TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error)
LiveVideoToVideo(context.Context, worker.GenLiveVideoToVideoJSONRequestBody) (*worker.LiveVideoToVideoResponse, error)
ImageToImageGeneric(context.Context, worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(string, string) bool
Expand Down
8 changes: 8 additions & 0 deletions core/ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,14 @@ func (a *stubAIWorker) LiveVideoToVideo(ctx context.Context, req worker.GenLiveV
return &worker.LiveVideoToVideoResponse{}, nil
}

// ImageToImageGeneric returns a canned single-image response so tests can
// exercise the generic image-to-image pipeline without a real AI runner.
func (a *stubAIWorker) ImageToImageGeneric(ctx context.Context, req worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) {
	stubMedia := worker.Media{Url: "http://example.com/image.png"}
	resp := &worker.ImageResponse{Images: []worker.Media{stubMedia}}
	return resp, nil
}

// Warm is a no-op stub; the real implementation warms a model on a runner endpoint.
func (a *stubAIWorker) Warm(ctx context.Context, arg1, arg2 string, endpoint worker.RunnerEndpoint, flags worker.OptimizationFlags) error {
	return nil
}
Expand Down
48 changes: 48 additions & 0 deletions core/ai_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,50 @@ func (orch *orchestrator) TextToSpeech(ctx context.Context, requestID string, re
return res.Results, nil
}

// ImageToImageGeneric routes a generic image-to-image request either to the
// local AI worker (combined orchestrator/ai worker node) or, when no local
// worker is attached, to a remote AI worker via the AIWorkerManager.
// It returns the saved pipeline results or an error.
func (orch *orchestrator) ImageToImageGeneric(ctx context.Context, requestID string, req worker.GenImageToImageGenericMultipartRequestBody) (interface{}, error) {
	// local AIWorker processes job if combined orchestrator/ai worker
	if orch.node.AIWorker != nil {
		workerResp, err := orch.node.ImageToImageGeneric(ctx, req)
		if err != nil {
			clog.Errorf(ctx, "Error processing with local ai worker err=%q", err)
			if monitor.Enabled {
				monitor.AIResultSaveError(ctx, "image-to-image-generic", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
			}
			return nil, err
		}
		return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "image/png")
	}

	// remote ai worker processes job: upload the input image so the worker
	// can download it, then strip the raw bytes from the forwarded request.
	imgBytes, err := req.Image.Bytes()
	if err != nil {
		return nil, err
	}

	inputUrl, err := orch.SaveAIRequestInput(ctx, requestID, imgBytes)
	if err != nil {
		return nil, err
	}
	req.Image.InitFromBytes(nil, "") // remove image data

	res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "image-to-image-generic", *req.ModelId, inputUrl, AIJobRequestData{Request: req, InputUrl: inputUrl})
	if err != nil {
		return nil, err
	}

	res, err = orch.node.saveRemoteAIWorkerResults(ctx, res, requestID)
	if err != nil {
		clog.Errorf(ctx, "Error processing with remote ai worker err=%q", err)
		if monitor.Enabled {
			monitor.AIResultSaveError(ctx, "image-to-image-generic", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
		}
		return nil, err
	}

	return res.Results, nil
}

// only used for sending work to remote AI worker
func (orch *orchestrator) SaveAIRequestInput(ctx context.Context, requestID string, fileData []byte) (string, error) {
node := orch.node
Expand Down Expand Up @@ -1062,6 +1106,10 @@ func (n *LivepeerNode) LiveVideoToVideo(ctx context.Context, req worker.GenLiveV
return n.AIWorker.LiveVideoToVideo(ctx, req)
}

// ImageToImageGeneric delegates a generic image-to-image request to the node's attached AI worker.
func (n *LivepeerNode) ImageToImageGeneric(ctx context.Context, req worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) {
	return n.AIWorker.ImageToImageGeneric(ctx, req)
}

// transcodeFrames converts a series of image URLs into a video segment for the image-to-video pipeline.
func (n *LivepeerNode) transcodeFrames(ctx context.Context, sessionID string, urls []string, inProfile ffmpeg.VideoProfile, outProfile ffmpeg.VideoProfile) *TranscodeResult {
ctx = clog.AddOrchSessionID(ctx, sessionID)
Expand Down
3 changes: 3 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ const (
Capability_ImageToText Capability = 34
Capability_LiveVideoToVideo Capability = 35
Capability_TextToSpeech Capability = 36
Capability_ImageToImageGeneric Capability = 37
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -124,6 +125,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_ImageToText: "Image to text",
Capability_LiveVideoToVideo: "Live video to video",
Capability_TextToSpeech: "Text to speech",
Capability_ImageToImageGeneric: "Image to image generic",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down Expand Up @@ -217,6 +219,7 @@ func OptionalCapabilities() []Capability {
Capability_SegmentAnything2,
Capability_ImageToText,
Capability_TextToSpeech,
Capability_ImageToImageGeneric,
}
}

Expand Down
30 changes: 29 additions & 1 deletion server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func startAIServer(lp *lphttp) error {
lp.transRPC.Handle("/image-to-text", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToTextMultipartRequestBody])))
lp.transRPC.Handle("/text-to-speech", oapiReqValidator(aiHttpHandle(lp, jsonDecoder[worker.GenTextToSpeechJSONRequestBody])))
lp.transRPC.Handle("/live-video-to-video", oapiReqValidator(lp.StartLiveVideoToVideo()))
lp.transRPC.Handle("/image-to-image-generic", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToImageGenericMultipartRequestBody])))
// Additionally, there is the '/aiResults' endpoint registered in server/rpc.go

return nil
Expand Down Expand Up @@ -470,6 +471,31 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
// TTS pricing is typically in characters, including punctuation.
words := utf8.RuneCountInString(*v.Text)
outPixels = int64(1000 * words)
case worker.GenImageToImageGenericMultipartRequestBody:
pipeline = "image-to-image-generic"
cap = core.Capability_ImageToImageGeneric
modelID = *v.ModelId
submitFn = func(ctx context.Context) (interface{}, error) {
return orch.ImageToImageGeneric(ctx, requestID, v)
}

imageRdr, err := v.Image.Reader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}
config, _, err := image.DecodeConfig(imageRdr)
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}
// NOTE: Should be enforced by the gateway, added for backwards compatibility.
numImages := int64(1)
if v.NumImagesPerPrompt != nil {
numImages = int64(*v.NumImagesPerPrompt)
}

outPixels = int64(config.Height) * int64(config.Width) * numImages
default:
respondWithError(w, "Unknown request type", http.StatusBadRequest)
return
Expand Down Expand Up @@ -575,6 +601,8 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
latencyScore = CalculateImageToTextLatencyScore(took, outPixels)
case worker.GenTextToSpeechJSONRequestBody:
latencyScore = CalculateTextToSpeechLatencyScore(took, outPixels)
case worker.GenImageToImageGenericMultipartRequestBody:
latencyScore = CalculateImageToImageGenericLatencyScore(took, v, outPixels)
}

var pricePerAIUnit float64
Expand Down Expand Up @@ -767,7 +795,7 @@ func parseMultiPartResult(body io.Reader, boundary string, pipeline string) core
if p.Header.Get("Content-Type") == "application/json" {
var results interface{}
switch pipeline {
case "text-to-image", "image-to-image", "upscale", "image-to-video":
case "text-to-image", "image-to-image", "upscale", "image-to-video", "image-to-image-generic":
var parsedResp worker.ImageResponse

err := json.Unmarshal(body, &parsedResp)
Expand Down
2 changes: 2 additions & 0 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ func startAIMediaServer(ls *LivepeerServer) error {
// Stream status
ls.HTTPMux.Handle("/live/video-to-video/{streamId}/status", ls.GetLiveVideoToVideoStatus())

ls.HTTPMux.Handle("/image-to-image-generic", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenImageToImageGenericMultipartRequestBody], processImageToImageGeneric)))

return nil
}

Expand Down
165 changes: 165 additions & 0 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"
const defaultImageToTextModelID = "Salesforce/blip-image-captioning-large"
const defaultLiveVideoToVideoModelID = "noop"
const defaultTextToSpeechModelID = "parler-tts/parler-tts-large-v1"
// NOTE(review): unlike the other default model IDs above, this is a map-like
// string keyed by sub-task ('inpainting', 'outpainting', 'sketch_to_image'),
// not a single model ID — confirm that model selection actually parses this
// format rather than matching it verbatim against advertised capabilities.
const defaultImageToImageGenericModelID = "{'inpainting':'kandinsky-community/kandinsky-2-2-decoder-inpaint', 'outpainting':'destitech/controlnet-inpaint-dreamer-sdxl', 'sketch_to_image':'xinsir/controlnet-scribble-sdxl-1.0'}"

var errWrongFormat = fmt.Errorf("result not in correct format")

Expand Down Expand Up @@ -1348,6 +1349,160 @@ func processImageToText(ctx context.Context, params aiRequestParams, req worker.
return txtResp, nil
}

// CalculateImageToImageGenericLatencyScore computes the time taken per output
// pixel, normalized by the number of inference steps, for an
// image-to-image-generic request. Returns 0 when outPixels is not positive.
func CalculateImageToImageGenericLatencyScore(took time.Duration, req worker.GenImageToImageGenericMultipartRequestBody, outPixels int64) float64 {
	if outPixels <= 0 {
		return 0
	}

	// TODO: Default values for the number of inference steps is currently hardcoded.
	// These should be managed by the nethttpmiddleware. Refer to issue LIV-412 for more details.
	steps := 100.0
	if s := req.NumInferenceSteps; s != nil {
		steps = math.Max(1, float64(*s))
	}
	// Handle special case for SDXL-Lightning model. (In generic case it is valid if any user uses it for stage2 pipeline for outpainting task)
	if strings.HasPrefix(*req.ModelId, "ByteDance/SDXL-Lightning") {
		steps = math.Max(1, core.ParseStepsFromModelID(req.ModelId, 8))
	}

	perPixel := took.Seconds() / float64(outPixels)
	return perPixel / steps
}

// processImageToImageGeneric runs a generic image-to-image request through
// processAIRequest, then re-hosts each returned image on the gateway's own
// object store and rewrites the media URLs before returning the response.
// The orchestrator may return each image either as a base64 data URL or as a
// downloadable URL; both forms are handled below.
func processImageToImageGeneric(ctx context.Context, params aiRequestParams, req worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) {
	resp, err := processAIRequest(ctx, params, req)
	if err != nil {
		return nil, err
	}

	imgResp, ok := resp.(*worker.ImageResponse)
	if !ok {
		return nil, errWrongFormat
	}

	newMedia := make([]worker.Media, len(imgResp.Images))
	for i, media := range imgResp.Images {
		var result []byte
		var data bytes.Buffer
		var name string
		writer := bufio.NewWriter(&data)
		// Try to decode .Url as a base64 data URL first; fall back to
		// treating it as a download URL when decoding fails.
		err := worker.ReadImageB64DataUrl(media.Url, writer)
		if err == nil {
			// orchestrator sent base64 encoded result in .Url; generate a
			// fresh random name since a data URL carries no filename.
			name = string(core.RandomManifestID()) + ".png"
			writer.Flush()
			result = data.Bytes()
		} else {
			// orchestrator sent download url, get the data
			name = filepath.Base(media.Url)
			result, err = core.DownloadData(ctx, media.Url)
			if err != nil {
				return nil, err
			}
		}

		newUrl, err := params.os.SaveData(ctx, name, bytes.NewReader(result), nil, 0)
		if err != nil {
			return nil, fmt.Errorf("error saving image to objectStore: %w", err)
		}

		// Preserve per-image metadata, swapping only the URL.
		newMedia[i] = worker.Media{Nsfw: media.Nsfw, Seed: media.Seed, Url: newUrl}
	}

	imgResp.Images = newMedia

	return imgResp, nil
}

// submitImageToImageGeneric sends a generic image-to-image request to the
// orchestrator session sess: it normalizes the image count, estimates the
// output pixels for payment, attaches payment headers, performs the multipart
// HTTP call, and records latency/pricing metrics. The request's
// NumImagesPerPrompt field may be mutated in place (defaulted/clamped).
func submitImageToImageGeneric(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) {
	// TODO: Default values for the number of images is currently hardcoded.
	// These should be managed by the nethttpmiddleware. Refer to issue LIV-412 for more details.
	defaultNumImages := 1
	if req.NumImagesPerPrompt == nil {
		req.NumImagesPerPrompt = &defaultNumImages
	} else {
		// Clamp to at least one image.
		*req.NumImagesPerPrompt = int(math.Max(1, float64(*req.NumImagesPerPrompt)))
	}

	// Serialize the request (including the image bytes) into a multipart body.
	var buf bytes.Buffer
	mw, err := worker.NewImageToImageGenericMultipartWriter(&buf, req)
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo)
		}
		return nil, err
	}

	client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo)
		}
		return nil, err
	}

	// Decode only the image header to learn its dimensions; the pixel count
	// (width * height * image count) is the unit used to price the request.
	imageRdr, err := req.Image.Reader()
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo)
		}
		return nil, err
	}
	config, _, err := image.DecodeConfig(imageRdr)
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo)
		}
		return nil, err
	}

	outPixels := int64(config.Height) * int64(config.Width) * int64(*req.NumImagesPerPrompt)

	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, outPixels)
	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo)
		}
		return nil, err
	}
	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)

	start := time.Now()
	resp, err := client.GenImageToImageGenericWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders)
	took := time.Since(start)

	// TODO: Refine this rough estimate in future iterations.
	sess.LatencyScore = CalculateImageToImageGenericLatencyScore(took, req, outPixels)

	if err != nil {
		if monitor.Enabled {
			monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo)
		}
		return nil, err
	}

	if resp.JSON200 == nil {
		// TODO: Replace trim newline with better error spec from O
		return nil, errors.New(strings.TrimSuffix(string(resp.Body), "\n"))
	}

	// We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
	if balUpdate != nil {
		balUpdate.Status = ReceivedChange
	}

	if monitor.Enabled {
		var pricePerAIUnit float64
		if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
			pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
		}

		monitor.AIRequestFinished(ctx, "image-to-image-generic", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
	}

	return resp.JSON200, nil
}

func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) {
var cap core.Capability
var modelID string
Expand Down Expand Up @@ -1451,6 +1606,16 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitLiveVideoToVideo(ctx, params, sess, v)
}
case worker.GenImageToImageGenericMultipartRequestBody:
cap = core.Capability_ImageToImageGeneric
modelID = defaultImageToImageGenericModelID
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitImageToImageGeneric(ctx, params, sess, v)
}
ctx = clog.AddVal(ctx, "prompt", v.Prompt)
default:
return nil, fmt.Errorf("unsupported request type %T", req)
}
Expand Down
17 changes: 17 additions & 0 deletions server/ai_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,23 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
return n.TextToSpeech(ctx, req)
}
reqOk = true
case "image-to-image-generic":
var req worker.GenImageToImageGenericMultipartRequestBody
err = json.Unmarshal(reqData.Request, &req)
if err != nil || req.ModelId == nil {
break
}
input, err = core.DownloadData(ctx, reqData.InputUrl)
if err != nil {
break
}
modelID = *req.ModelId
resultType = "image/png"
req.Image.InitFromBytes(input, "image")
processFn = func(ctx context.Context) (interface{}, error) {
return n.ImageToImageGeneric(ctx, req)
}
reqOk = true
default:
err = errors.New("AI request pipeline type not supported")
}
Expand Down
Loading
Loading