From 3c6f2ef3bda8c239c0616a8243753d8e5c3a6a0c Mon Sep 17 00:00:00 2001 From: RUFFY-369 Date: Tue, 31 Dec 2024 23:38:25 +0530 Subject: [PATCH] feat:add initial implementation of i2i generic tasks integration --- core/ai.go | 1 + core/ai_test.go | 8 ++ core/ai_worker.go | 48 ++++++++++++ core/capabilities.go | 3 + server/ai_http.go | 30 ++++++- server/ai_mediaserver.go | 2 + server/ai_process.go | 165 +++++++++++++++++++++++++++++++++++++++ server/ai_worker.go | 17 ++++ server/ai_worker_test.go | 38 ++++++++- server/rpc.go | 1 + server/rpc_test.go | 6 ++ 11 files changed, 316 insertions(+), 3 deletions(-) diff --git a/core/ai.go b/core/ai.go index a9eeae9f7..ba8526f11 100644 --- a/core/ai.go +++ b/core/ai.go @@ -28,6 +28,7 @@ type AI interface { ImageToText(context.Context, worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error) TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error) LiveVideoToVideo(context.Context, worker.GenLiveVideoToVideoJSONRequestBody) (*worker.LiveVideoToVideoResponse, error) + ImageToImageGeneric(context.Context, worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error Stop(context.Context) error HasCapacity(string, string) bool diff --git a/core/ai_test.go b/core/ai_test.go index 3e4ab8207..6ce4b6ae0 100644 --- a/core/ai_test.go +++ b/core/ai_test.go @@ -667,6 +667,14 @@ func (a *stubAIWorker) LiveVideoToVideo(ctx context.Context, req worker.GenLiveV return &worker.LiveVideoToVideoResponse{}, nil } +func (a *stubAIWorker) ImageToImageGeneric(ctx context.Context, req worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) { + return &worker.ImageResponse{ + Images: []worker.Media{ + {Url: "http://example.com/image.png"}, + }, + }, nil +} + func (a *stubAIWorker) Warm(ctx context.Context, arg1, arg2 string, endpoint worker.RunnerEndpoint, flags worker.OptimizationFlags) error { return nil } diff --git a/core/ai_worker.go b/core/ai_worker.go index 235338fca..28b32def5 100644 --- a/core/ai_worker.go +++ b/core/ai_worker.go @@ -884,6 +884,50 @@ func (orch *orchestrator) TextToSpeech(ctx context.Context, requestID string, re return res.Results, nil } +func (orch *orchestrator) ImageToImageGeneric(ctx context.Context, requestID string, req worker.GenImageToImageGenericMultipartRequestBody) (interface{}, error) { + // local AIWorker processes job if combined orchestrator/ai worker + if orch.node.AIWorker != nil { + workerResp, err := orch.node.ImageToImageGeneric(ctx, req) + if err == nil { + return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "image/png") + } else { + clog.Errorf(ctx, "Error processing with local ai worker err=%q", err) + if monitor.Enabled { + monitor.AIResultSaveError(ctx, "image-to-image-generic", *req.ModelId, string(monitor.SegmentUploadErrorUnknown)) + } + return nil, err + } + } + + // remote ai worker proceses job + imgBytes, err := req.Image.Bytes() + if err != nil { + return nil, err + } + + inputUrl, err := orch.SaveAIRequestInput(ctx, requestID, imgBytes) + if err != nil { + return nil, err + } + req.Image.InitFromBytes(nil, "") // remove image data + + res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "image-to-image-generic", *req.ModelId, inputUrl, AIJobRequestData{Request: req, InputUrl: inputUrl}) + if err != nil { + return nil, err + } + + res, err = orch.node.saveRemoteAIWorkerResults(ctx, res, requestID) + if err != nil { + clog.Errorf(ctx, "Error processing with local ai worker err=%q", err) + if monitor.Enabled { + monitor.AIResultSaveError(ctx, "image-to-image-generic", *req.ModelId, string(monitor.SegmentUploadErrorUnknown)) + } + return nil, err + } + + return res.Results, nil +} + // only used for sending work to remote AI worker func (orch *orchestrator) SaveAIRequestInput(ctx context.Context, requestID string, fileData []byte) (string, error) { node := orch.node @@ -1062,6 +1106,10 @@ func (n *LivepeerNode) LiveVideoToVideo(ctx context.Context, req worker.GenLiveV return n.AIWorker.LiveVideoToVideo(ctx, req) } +func (n *LivepeerNode) ImageToImageGeneric(ctx context.Context, req worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) { + return n.AIWorker.ImageToImageGeneric(ctx, req) +} + // transcodeFrames converts a series of image URLs into a video segment for the image-to-video pipeline. func (n *LivepeerNode) transcodeFrames(ctx context.Context, sessionID string, urls []string, inProfile ffmpeg.VideoProfile, outProfile ffmpeg.VideoProfile) *TranscodeResult { ctx = clog.AddOrchSessionID(ctx, sessionID) diff --git a/core/capabilities.go b/core/capabilities.go index d2425fa98..83d71b2d9 100644 --- a/core/capabilities.go +++ b/core/capabilities.go @@ -83,6 +83,7 @@ const ( Capability_ImageToText Capability = 34 Capability_LiveVideoToVideo Capability = 35 Capability_TextToSpeech Capability = 36 + Capability_ImageToImageGeneric Capability = 37 ) var CapabilityNameLookup = map[Capability]string{ @@ -124,6 +125,7 @@ var CapabilityNameLookup = map[Capability]string{ Capability_ImageToText: "Image to text", Capability_LiveVideoToVideo: "Live video to video", Capability_TextToSpeech: "Text to speech", + Capability_ImageToImageGeneric: "Image to image generic", } var CapabilityTestLookup = map[Capability]CapabilityTest{ @@ -217,6 +219,7 @@ func OptionalCapabilities() []Capability { Capability_SegmentAnything2, Capability_ImageToText, Capability_TextToSpeech, + Capability_ImageToImageGeneric, } } diff --git a/server/ai_http.go b/server/ai_http.go index f738f3df0..b5bcf1c87 100644 --- a/server/ai_http.go +++ b/server/ai_http.go @@ -71,6 +71,7 @@ func startAIServer(lp *lphttp) error { lp.transRPC.Handle("/image-to-text", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToTextMultipartRequestBody]))) lp.transRPC.Handle("/text-to-speech", oapiReqValidator(aiHttpHandle(lp, jsonDecoder[worker.GenTextToSpeechJSONRequestBody]))) lp.transRPC.Handle("/live-video-to-video", oapiReqValidator(lp.StartLiveVideoToVideo())) + lp.transRPC.Handle("/image-to-image-generic", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToImageGenericMultipartRequestBody]))) // Additionally, there is the '/aiResults' endpoint registered in server/rpc.go return nil @@ -470,6 +471,31 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request // TTS pricing is typically in characters, including punctuation. words := utf8.RuneCountInString(*v.Text) outPixels = int64(1000 * words) + case worker.GenImageToImageGenericMultipartRequestBody: + pipeline = "image-to-image-generic" + cap = core.Capability_ImageToImageGeneric + modelID = *v.ModelId + submitFn = func(ctx context.Context) (interface{}, error) { + return orch.ImageToImageGeneric(ctx, requestID, v) + } + + imageRdr, err := v.Image.Reader() + if err != nil { + respondWithError(w, err.Error(), http.StatusBadRequest) + return + } + config, _, err := image.DecodeConfig(imageRdr) + if err != nil { + respondWithError(w, err.Error(), http.StatusBadRequest) + return + } + // NOTE: Should be enforced by the gateway, added for backwards compatibility. + numImages := int64(1) + if v.NumImagesPerPrompt != nil { + numImages = int64(*v.NumImagesPerPrompt) + } + + outPixels = int64(config.Height) * int64(config.Width) * numImages default: respondWithError(w, "Unknown request type", http.StatusBadRequest) return @@ -575,6 +601,8 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request latencyScore = CalculateImageToTextLatencyScore(took, outPixels) case worker.GenTextToSpeechJSONRequestBody: latencyScore = CalculateTextToSpeechLatencyScore(took, outPixels) + case worker.GenImageToImageGenericMultipartRequestBody: + latencyScore = CalculateImageToImageGenericLatencyScore(took, v, outPixels) } var pricePerAIUnit float64 @@ -767,7 +795,7 @@ func parseMultiPartResult(body io.Reader, boundary string, pipeline string) core if p.Header.Get("Content-Type") == "application/json" { var results interface{} switch pipeline { - case "text-to-image", "image-to-image", "upscale", "image-to-video": + case "text-to-image", "image-to-image", "upscale", "image-to-video", "image-to-image-generic": var parsedResp worker.ImageResponse err := json.Unmarshal(body, &parsedResp) diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go index 80f92948b..3bc44ca0b 100644 --- a/server/ai_mediaserver.go +++ b/server/ai_mediaserver.go @@ -90,6 +90,8 @@ func startAIMediaServer(ls *LivepeerServer) error { // Stream status ls.HTTPMux.Handle("/live/video-to-video/{streamId}/status", ls.GetLiveVideoToVideoStatus()) + ls.HTTPMux.Handle("/image-to-image-generic", oapiReqValidator(aiMediaServerHandle(ls, multipartDecoder[worker.GenImageToImageGenericMultipartRequestBody], processImageToImageGeneric))) + return nil } diff --git a/server/ai_process.go b/server/ai_process.go index ea15bd43e..d7f801bb1 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -38,6 +38,7 @@ const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large" const defaultImageToTextModelID = "Salesforce/blip-image-captioning-large" const defaultLiveVideoToVideoModelID = "noop" const defaultTextToSpeechModelID = "parler-tts/parler-tts-large-v1" +const defaultImageToImageGenericModelID = "{'inpainting':'kandinsky-community/kandinsky-2-2-decoder-inpaint', 'outpainting':'destitech/controlnet-inpaint-dreamer-sdxl', 'sketch_to_image':'xinsir/controlnet-scribble-sdxl-1.0'}" var errWrongFormat = fmt.Errorf("result not in correct format") @@ -1348,6 +1349,160 @@ func processImageToText(ctx context.Context, params aiRequestParams, req worker. return txtResp, nil } +// CalculateImageToImageGenericLatencyScore computes the time taken per pixel for an image-to-image-generic request. +func CalculateImageToImageGenericLatencyScore(took time.Duration, req worker.GenImageToImageGenericMultipartRequestBody, outPixels int64) float64 { + if outPixels <= 0 { + return 0 + } + + // TODO: Default values for the number of inference steps is currently hardcoded. + // These should be managed by the nethttpmiddleware. Refer to issue LIV-412 for more details. + numInferenceSteps := float64(100) + if req.NumInferenceSteps != nil { + numInferenceSteps = math.Max(1, float64(*req.NumInferenceSteps)) + } + // Handle special case for SDXL-Lightning model. (In generic case it is valid if any user uses it for stage2 pipeline for outpainting task) + if strings.HasPrefix(*req.ModelId, "ByteDance/SDXL-Lightning") { + numInferenceSteps = math.Max(1, core.ParseStepsFromModelID(req.ModelId, 8)) + } + + return took.Seconds() / float64(outPixels) / numInferenceSteps +} + +func processImageToImageGeneric(ctx context.Context, params aiRequestParams, req worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) { + resp, err := processAIRequest(ctx, params, req) + if err != nil { + return nil, err + } + + imgResp, ok := resp.(*worker.ImageResponse) + if !ok { + return nil, errWrongFormat + } + + newMedia := make([]worker.Media, len(imgResp.Images)) + for i, media := range imgResp.Images { + var result []byte + var data bytes.Buffer + var name string + writer := bufio.NewWriter(&data) + err := worker.ReadImageB64DataUrl(media.Url, writer) + if err == nil { + // orchestrator sent bae64 encoded result in .Url + name = string(core.RandomManifestID()) + ".png" + writer.Flush() + result = data.Bytes() + } else { + // orchestrator sent download url, get the data + name = filepath.Base(media.Url) + result, err = core.DownloadData(ctx, media.Url) + if err != nil { + return nil, err + } + } + + newUrl, err := params.os.SaveData(ctx, name, bytes.NewReader(result), nil, 0) + if err != nil { + return nil, fmt.Errorf("error saving image to objectStore: %w", err) + } + + newMedia[i] = worker.Media{Nsfw: media.Nsfw, Seed: media.Seed, Url: newUrl} + } + + imgResp.Images = newMedia + + return imgResp, nil +} + +func submitImageToImageGeneric(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) { + // TODO: Default values for the number of images is currently hardcoded. + // These should be managed by the nethttpmiddleware. Refer to issue LIV-412 for more details. + defaultNumImages := 1 + if req.NumImagesPerPrompt == nil { + req.NumImagesPerPrompt = &defaultNumImages + } else { + *req.NumImagesPerPrompt = int(math.Max(1, float64(*req.NumImagesPerPrompt))) + } + + var buf bytes.Buffer + mw, err := worker.NewImageToImageGenericMultipartWriter(&buf, req) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient)) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + imageRdr, err := req.Image.Reader() + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + config, _, err := image.DecodeConfig(imageRdr) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + outPixels := int64(config.Height) * int64(config.Width) * int64(*req.NumImagesPerPrompt) + + setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, outPixels) + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + defer completeBalanceUpdate(sess.BroadcastSession, balUpdate) + + start := time.Now() + resp, err := client.GenImageToImageGenericWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders) + took := time.Since(start) + + // TODO: Refine this rough estimate in future iterations. + sess.LatencyScore = CalculateImageToImageGenericLatencyScore(took, req, outPixels) + + if err != nil { + if monitor.Enabled { + monitor.AIRequestError(err.Error(), "image-to-image-generic", *req.ModelId, sess.OrchestratorInfo) + } + return nil, err + } + + if resp.JSON200 == nil { + // TODO: Replace trim newline with better error spec from O + return nil, errors.New(strings.TrimSuffix(string(resp.Body), "\n")) + } + + // We treat a response as "receiving change" where the change is the difference between the credit and debit for the update + if balUpdate != nil { + balUpdate.Status = ReceivedChange + } + + if monitor.Enabled { + var pricePerAIUnit float64 + if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 { + pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit) + } + + monitor.AIRequestFinished(ctx, "image-to-image-generic", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo) + } + + return resp.JSON200, nil +} + func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) { var cap core.Capability var modelID string @@ -1451,6 +1606,16 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitLiveVideoToVideo(ctx, params, sess, v) } + case worker.GenImageToImageGenericMultipartRequestBody: + cap = core.Capability_ImageToImageGeneric + modelID = defaultImageToImageGenericModelID + if v.ModelId != nil { + modelID = *v.ModelId + } + submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { + return submitImageToImageGeneric(ctx, params, sess, v) + } + ctx = clog.AddVal(ctx, "prompt", v.Prompt) default: return nil, fmt.Errorf("unsupported request type %T", req) } diff --git a/server/ai_worker.go b/server/ai_worker.go index 34dc722bf..10066527d 100644 --- a/server/ai_worker.go +++ b/server/ai_worker.go @@ -314,6 +314,23 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify return n.TextToSpeech(ctx, req) } reqOk = true + case "image-to-image-generic": + var req worker.GenImageToImageGenericMultipartRequestBody + err = json.Unmarshal(reqData.Request, &req) + if err != nil || req.ModelId == nil { + break + } + input, err = core.DownloadData(ctx, reqData.InputUrl) + if err != nil { + break + } + modelID = *req.ModelId + resultType = "image/png" + req.Image.InitFromBytes(input, "image") + processFn = func(ctx context.Context) (interface{}, error) { + return n.ImageToImageGeneric(ctx, req) + } + reqOk = true default: err = errors.New("AI request pipeline type not supported") } diff --git a/server/ai_worker_test.go b/server/ai_worker_test.go index ab31a3e71..baaaf4c14 100644 --- a/server/ai_worker_test.go +++ b/server/ai_worker_test.go @@ -218,16 +218,23 @@ func TestRunAIJob(t *testing.T) { expectedErr: "", expectedOutputs: 1, }, + { + name: "ImageToImageGeneric_Success", + notify: createAIJob(10, "image-to-image-generic", modelId, parsedURL.String()+"/image.png"), + pipeline: "image-to-image-generic", + expectedErr: "", + expectedOutputs: 1, + }, { name: "UnsupportedPipeline", - notify: createAIJob(10, "unsupported-pipeline", modelId, ""), + notify: createAIJob(11, "unsupported-pipeline", modelId, ""), pipeline: "unsupported-pipeline", expectedErr: "AI request validation failed for", expectedOutputs: 0, }, { name: "InvalidRequestData", - notify: createAIJob(11, "text-to-image-invalid", modelId, ""), + notify: createAIJob(12, "text-to-image-invalid", modelId, ""), pipeline: "text-to-image", expectedErr: "AI request validation failed for", expectedOutputs: 0, @@ -344,6 +351,13 @@ func TestRunAIJob(t *testing.T) { var respFile bytes.Buffer worker.ReadAudioB64DataUrl(expectedResp.Audio.Url, &respFile) assert.Equal(len(results.Files[audResp.Audio.Url]), respFile.Len()) + case "image-to-image-generic": + i2iResp, ok := results.Results.(worker.ImageResponse) + assert.True(ok) + assert.Equal("10", headers.Get("TaskId")) + assert.Equal(len(results.Files), 1) + expectedResp, _ := wkr.ImageToImageGeneric(context.Background(), worker.GenImageToImageGenericMultipartRequestBody{}) + assert.Equal(expectedResp.Images[0].Seed, i2iResp.Images[0].Seed) } } }) @@ -380,6 +394,9 @@ func createAIJob(taskId int64, pipeline, modelId, inputUrl string) *net.NotifyAI desc := "a young adult" text := "let me tell you a story" req = worker.GenTextToSpeechJSONRequestBody{Description: &desc, ModelId: &modelId, Text: &text} + case "image-to-image-generic": + inputFile.InitFromBytes(nil, inputUrl) + req = worker.GenImageToImageGenericMultipartRequestBody{Prompt: "test prompt", ModelId: &modelId, Image: inputFile} case "unsupported-pipeline": req = worker.GenTextToImageJSONRequestBody{Prompt: "test prompt", ModelId: &modelId} case "text-to-image-invalid": @@ -635,6 +652,23 @@ func (a *stubAIWorker) LiveVideoToVideo(ctx context.Context, req worker.GenLiveV } } +func (a *stubAIWorker) ImageToImageGeneric(ctx context.Context, req worker.GenImageToImageGenericMultipartRequestBody) (*worker.ImageResponse, error) { + a.Called++ + if a.Err != nil { + return nil, a.Err + } else { + return &worker.ImageResponse{ + Images: []worker.Media{ + { + Url: "", + Nsfw: false, + Seed: 115, + }, + }, + }, nil + } +} + func (a *stubAIWorker) Warm(ctx context.Context, arg1, arg2 string, endpoint worker.RunnerEndpoint, flags worker.OptimizationFlags) error { a.Called++ return nil diff --git a/server/rpc.go b/server/rpc.go index 7223c56a9..bd27405fd 100644 --- a/server/rpc.go +++ b/server/rpc.go @@ -79,6 +79,7 @@ type Orchestrator interface { ImageToText(ctx context.Context, requestID string, req worker.GenImageToTextMultipartRequestBody) (interface{}, error) TextToSpeech(ctx context.Context, requestID string, req worker.GenTextToSpeechJSONRequestBody) (interface{}, error) LiveVideoToVideo(ctx context.Context, requestID string, req worker.GenLiveVideoToVideoJSONRequestBody) (interface{}, error) + ImageToImageGeneric(ctx context.Context, requestID string, req worker.GenImageToImageGenericMultipartRequestBody) (interface{}, error) } // Balance describes methods for a session's balance maintenance diff --git a/server/rpc_test.go b/server/rpc_test.go index 43ec1a304..825ce0363 100644 --- a/server/rpc_test.go +++ b/server/rpc_test.go @@ -226,6 +226,9 @@ func (r *stubOrchestrator) TextToSpeech(ctx context.Context, requestID string, r func (r *stubOrchestrator) LiveVideoToVideo(ctx context.Context, requestID string, req worker.GenLiveVideoToVideoJSONRequestBody) (interface{}, error) { return nil, nil } +func (r *stubOrchestrator) ImageToImageGeneric(ctx context.Context, requestID string, req worker.GenImageToImageGenericMultipartRequestBody) (interface{}, error) { + return nil, nil +} func (r *stubOrchestrator) CheckAICapacity(pipeline, modelID string) bool { return true @@ -1432,6 +1435,9 @@ func (r *mockOrchestrator) TextToSpeech(ctx context.Context, requestID string, r func (r *mockOrchestrator) LiveVideoToVideo(ctx context.Context, requestID string, req worker.GenLiveVideoToVideoJSONRequestBody) (interface{}, error) { return nil, nil } +func (r *mockOrchestrator) ImageToImageGeneric(ctx context.Context, requestID string, req worker.GenImageToImageGenericMultipartRequestBody) (interface{}, error) { + return nil, nil +} func (r *mockOrchestrator) CheckAICapacity(pipeline, modelID string) bool { return true }