From e06cf41705210e9c530c3ccca8dfbf35e73caf1b Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Mon, 12 Feb 2024 19:44:11 +1300 Subject: [PATCH 1/5] Wrap and simplify Image input --- .../Examples/ControlNetExample.cs | 6 +- .../Examples/ControlNetFeatureExample.cs | 8 +- .../Examples/FeatureExtractorExample.cs | 4 +- OnnxStack.Console/Examples/StableDebug.cs | 4 +- .../Examples/StableDiffusionBatch.cs | 4 +- .../Examples/StableDiffusionExample.cs | 4 +- .../Examples/StableDiffusionGenerator.cs | 4 +- OnnxStack.Console/Examples/UpscaleExample.cs | 6 +- OnnxStack.Core/Image/Extensions.cs | 501 +----------------- OnnxStack.Core/Image/InputImage.cs | 105 ---- OnnxStack.Core/Image/OnnxImage.cs | 395 ++++++++++++++ OnnxStack.Core/Video/Extensions.cs | 18 +- OnnxStack.Core/Video/VideoFrame.cs | 2 +- .../Pipelines/FeatureExtractorPipeline.cs | 12 +- .../Extensions/ImageExtensions.cs | 9 +- OnnxStack.ImageUpscaler/Models/ImageTile.cs | 5 +- .../Pipelines/ImageUpscalePipeline.cs | 40 +- .../Services/IUpscaleService.cs | 8 +- .../Services/UpscaleService.cs | 68 ++- .../Common/BatchResult.cs | 4 +- .../Config/PromptOptions.cs | 6 +- .../Diffusers/InstaFlow/ControlNetDiffuser.cs | 2 +- .../LatentConsistency/ControlNetDiffuser.cs | 2 +- .../ControlNetImageDiffuser.cs | 2 +- .../LatentConsistency/ImageDiffuser.cs | 2 +- .../InpaintLegacyDiffuser.cs | 4 +- .../LatentConsistencyXL/ControlNetDiffuser.cs | 2 +- .../ControlNetImageDiffuser.cs | 2 +- .../LatentConsistencyXL/ImageDiffuser.cs | 2 +- .../InpaintLegacyDiffuser.cs | 4 +- .../StableDiffusion/ControlNetDiffuser.cs | 2 +- .../ControlNetImageDiffuser.cs | 2 +- .../StableDiffusion/ImageDiffuser.cs | 2 +- .../StableDiffusion/InpaintDiffuser.cs | 6 +- .../StableDiffusion/InpaintLegacyDiffuser.cs | 4 +- .../StableDiffusionXL/ControlNetDiffuser.cs | 2 +- .../ControlNetImageDiffuser.cs | 2 +- .../StableDiffusionXL/ImageDiffuser.cs | 2 +- .../InpaintLegacyDiffuser.cs | 4 +- .../Pipelines/Base/PipelineBase.cs | 4 +- .../Pipelines/StableDiffusionPipeline.cs | 3 +- .../Services/IStableDiffusionService.cs | 73 +-- .../Services/StableDiffusionService.cs | 107 +--- OnnxStack.UI/Views/ImageInpaintView.xaml.cs | 16 +- OnnxStack.UI/Views/ImageToImageView.xaml.cs | 14 +- OnnxStack.UI/Views/TextToImageView.xaml.cs | 7 +- OnnxStack.UI/Views/UpscaleView.xaml.cs | 2 +- OnnxStack.UI/Views/VideoToVideoView.xaml.cs | 10 +- 48 files changed, 556 insertions(+), 941 deletions(-) delete mode 100644 OnnxStack.Core/Image/InputImage.cs create mode 100644 OnnxStack.Core/Image/OnnxImage.cs diff --git a/OnnxStack.Console/Examples/ControlNetExample.cs b/OnnxStack.Console/Examples/ControlNetExample.cs index 4881efa0..25e26bf1 100644 --- a/OnnxStack.Console/Examples/ControlNetExample.cs +++ b/OnnxStack.Console/Examples/ControlNetExample.cs @@ -32,7 +32,7 @@ public ControlNetExample(StableDiffusionConfig configuration) public async Task RunAsync() { // Load Control Image - var controlImage = await InputImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\OpenPose.png"); + var controlImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\OpenPose.png"); // Create ControlNet var controlNet = ControlNetModel.Create("D:\\Repositories\\controlnet_onnx\\controlnet\\openpose.onnx"); @@ -54,11 +54,11 @@ public async Task RunAsync() var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback); // Create Image from Tensor result - var image = result.ToImage(); + var image = new 
OnnxImage(result); // Save Image File var outputFilename = Path.Combine(_outputDirectory, $"Output.png"); - await image.SaveAsPngAsync(outputFilename); + await image.SaveAsync(outputFilename); //Unload await controlNet.UnloadAsync(); diff --git a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs index f0664280..a3bd0e45 100644 --- a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs +++ b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs @@ -32,7 +32,7 @@ public ControlNetFeatureExample(StableDiffusionConfig configuration) public async Task RunAsync() { // Load Control Image - var inputImage = await InputImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp"); + var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp"); // Create Annotation pipeline var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", true); @@ -41,7 +41,7 @@ public async Task RunAsync() var controlImage = await annotationPipeline.RunAsync(inputImage); // Save Depth Image (Debug Only) - await controlImage.Image.SaveAsPngAsync(Path.Combine(_outputDirectory, $"Depth.png")); + await controlImage.SaveAsync(Path.Combine(_outputDirectory, $"Depth.png")); // Create ControlNet var controlNet = ControlNetModel.Create("D:\\Repositories\\controlnet_onnx\\controlnet\\depth.onnx"); @@ -61,11 +61,11 @@ public async Task RunAsync() var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback); // Create Image from Tensor result - var image = result.ToImage(); + var image = new OnnxImage(result); // Save Image File var outputFilename = Path.Combine(_outputDirectory, $"Output.png"); - await image.SaveAsPngAsync(outputFilename); + await image.SaveAsync(outputFilename); //Unload await annotationPipeline.UnloadAsync(); diff --git a/OnnxStack.Console/Examples/FeatureExtractorExample.cs b/OnnxStack.Console/Examples/FeatureExtractorExample.cs index d4cfdb65..3cbc14cc 100644 --- a/OnnxStack.Console/Examples/FeatureExtractorExample.cs +++ b/OnnxStack.Console/Examples/FeatureExtractorExample.cs @@ -30,7 +30,7 @@ public FeatureExtractorExample(StableDiffusionConfig configuration) public async Task RunAsync() { // Load Control Image - var inputImage = await InputImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp"); + var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp"); var pipelines = new[] { @@ -53,7 +53,7 @@ public async Task RunAsync() OutputHelpers.WriteConsole($"Generating image", ConsoleColor.Cyan); // Save Image - await imageFeature.Image.SaveAsPngAsync(Path.Combine(_outputDirectory, $"{pipeline.Name}.png")); + await imageFeature.SaveAsync(Path.Combine(_outputDirectory, $"{pipeline.Name}.png")); OutputHelpers.WriteConsole($"Unload pipeline", ConsoleColor.Cyan); diff --git a/OnnxStack.Console/Examples/StableDebug.cs b/OnnxStack.Console/Examples/StableDebug.cs index c09ce66a..74da1264 100644 --- a/OnnxStack.Console/Examples/StableDebug.cs +++ b/OnnxStack.Console/Examples/StableDebug.cs @@ -69,11 +69,11 @@ public async Task RunAsync() var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback); // Create Image from Tensor result - var image = result.ToImage(); + var image = new OnnxImage(result); // Save 
Image File var outputFilename = Path.Combine(_outputDirectory, $"{modelSet.Name}_{schedulerOptions.SchedulerType}.png"); - await image.SaveAsPngAsync(outputFilename); + await image.SaveAsync(outputFilename); OutputHelpers.WriteConsole($"{schedulerOptions.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green); OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow); diff --git a/OnnxStack.Console/Examples/StableDiffusionBatch.cs b/OnnxStack.Console/Examples/StableDiffusionBatch.cs index 3b624f8b..6b48dffe 100644 --- a/OnnxStack.Console/Examples/StableDiffusionBatch.cs +++ b/OnnxStack.Console/Examples/StableDiffusionBatch.cs @@ -61,11 +61,11 @@ public async Task RunAsync() await foreach (var result in pipeline.RunBatchAsync(batchOptions, promptOptions, progressCallback: OutputHelpers.BatchProgressCallback)) { // Create Image from Tensor result - var image = result.ImageResult.ToImage(); + var image = result.ImageResult; // Save Image File var outputFilename = Path.Combine(_outputDirectory, $"{modelSet.Name}_{result.SchedulerOptions.Seed}.png"); - await image.SaveAsPngAsync(outputFilename); + await image.SaveAsync(outputFilename); OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}, Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Green); timestamp = Stopwatch.GetTimestamp(); diff --git a/OnnxStack.Console/Examples/StableDiffusionExample.cs b/OnnxStack.Console/Examples/StableDiffusionExample.cs index c422a0bd..09564d50 100644 --- a/OnnxStack.Console/Examples/StableDiffusionExample.cs +++ b/OnnxStack.Console/Examples/StableDiffusionExample.cs @@ -70,11 +70,11 @@ public async Task RunAsync() var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback); // Create Image from Tensor result - var image = result.ToImage(); + var image = new OnnxImage(result); // Save Image File var outputFilename = Path.Combine(_outputDirectory, $"{modelSet.Name}_{schedulerOptions.SchedulerType}.png"); - await image.SaveAsPngAsync(outputFilename); + await image.SaveAsync(outputFilename); OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}, Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Green); } diff --git a/OnnxStack.Console/Examples/StableDiffusionGenerator.cs b/OnnxStack.Console/Examples/StableDiffusionGenerator.cs index 62aa0572..f11aeb8b 100644 --- a/OnnxStack.Console/Examples/StableDiffusionGenerator.cs +++ b/OnnxStack.Console/Examples/StableDiffusionGenerator.cs @@ -58,11 +58,11 @@ public async Task RunAsync() var result = await pipeline.RunAsync(promptOptions, progressCallback: OutputHelpers.ProgressCallback); // Create Image from Tensor result - var image = result.ToImage(); + var image = new OnnxImage(result); // Save Image File var outputFilename = Path.Combine(_outputDirectory, $"{modelSet.Name}_{generationPrompt.Key}.png"); - await image.SaveAsPngAsync(outputFilename); + await image.SaveAsync(outputFilename); OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}, Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Green); } diff --git a/OnnxStack.Console/Examples/UpscaleExample.cs b/OnnxStack.Console/Examples/UpscaleExample.cs index 0fb3d163..8c232ae9 100644 --- a/OnnxStack.Console/Examples/UpscaleExample.cs +++ b/OnnxStack.Console/Examples/UpscaleExample.cs @@ -26,7 +26,7 @@ public UpscaleExample(ImageUpscalerConfig configuration) public async Task 
RunAsync() { // Load Input Image - var inputImage = await InputImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp"); + var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp"); // Create Pipeline var pipeline = ImageUpscalePipeline.CreatePipeline("D:\\Repositories\\upscaler\\SwinIR\\003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx", 4); @@ -35,11 +35,11 @@ public async Task RunAsync() var result = await pipeline.RunAsync(inputImage); // Create Image from Tensor result - var image = result.ToImage(ImageNormalizeType.ZeroToOne); + var image = new OnnxImage(result, ImageNormalizeType.ZeroToOne); // Save Image File var outputFilename = Path.Combine(_outputDirectory, $"Upscaled.png"); - await image.SaveAsPngAsync(outputFilename); + await image.SaveAsync(outputFilename); // Unload await pipeline.UnloadAsync(); diff --git a/OnnxStack.Core/Image/Extensions.cs b/OnnxStack.Core/Image/Extensions.cs index a19a38a0..8a74b4e8 100644 --- a/OnnxStack.Core/Image/Extensions.cs +++ b/OnnxStack.Core/Image/Extensions.cs @@ -1,142 +1,18 @@ -using Microsoft.ML.OnnxRuntime; -using Microsoft.ML.OnnxRuntime.Tensors; +using Microsoft.ML.OnnxRuntime.Tensors; using SixLabors.ImageSharp; using SixLabors.ImageSharp.PixelFormats; -using SixLabors.ImageSharp.Processing; -using System; -using System.IO; -using System.Threading.Tasks; -using ImageSharp = SixLabors.ImageSharp.Image; namespace OnnxStack.Core.Image { public static class Extensions { - #region Image To Image - - /// - /// Converts to InputImage to an Image - /// - /// The input image. - /// - public static Image ToImage(this InputImage inputImage) - { - if (!string.IsNullOrEmpty(inputImage.ImageBase64)) - return ImageSharp.Load(Convert.FromBase64String(inputImage.ImageBase64.Split(',')[1])); - if (inputImage.ImageBytes != null) - return ImageSharp.Load(inputImage.ImageBytes); - if (inputImage.ImageStream != null) - return ImageSharp.Load(inputImage.ImageStream); - if (inputImage.ImageTensor != null) - return inputImage.ImageTensor.ToImage(); - - return inputImage.Image; - } - - - /// - /// Converts to InputImage to an Image - /// - /// The input image. - /// - public static async Task> ToImageAsync(this InputImage inputImage) - { - return await Task.Run(inputImage.ToImage); - } - - - /// - /// Resizes the specified image. - /// - /// The image. - /// The dimensions. - public static void Resize(this Image image, int height, int width, ResizeMode resizeMode = ResizeMode.Crop) - { - image.Mutate(x => - { - x.Resize(new ResizeOptions - { - Size = new Size(width, height), - Mode = resizeMode, - Sampler = KnownResamplers.Lanczos8, - Compand = true - }); - }); - } - - #endregion - - #region Tensor To Image - - /// - /// Converts to image byte array. - /// - /// The image tensor. - /// - public static byte[] ToImageBytes(this DenseTensor imageTensor, ImageNormalizeType imageNormalizeType = ImageNormalizeType.OneToOne) - { - using (var image = imageTensor.ToImage(imageNormalizeType)) - using (var memoryStream = new MemoryStream()) - { - image.SaveAsPng(memoryStream); - return memoryStream.ToArray(); - } - } - - /// - /// Converts to image byte array. - /// - /// The image tensor. 
- /// - public static async Task ToImageBytesAsync(this DenseTensor imageTensor, ImageNormalizeType imageNormalizeType = ImageNormalizeType.OneToOne) - { - using (var image = imageTensor.ToImage(imageNormalizeType)) - using (var memoryStream = new MemoryStream()) - { - await image.SaveAsPngAsync(memoryStream); - return memoryStream.ToArray(); - } - } - - - /// - /// Converts to image memory stream. - /// - /// The image tensor. - /// - public static Stream ToImageStream(this DenseTensor imageTensor, ImageNormalizeType imageNormalizeType = ImageNormalizeType.OneToOne) - { - using (var image = imageTensor.ToImage(imageNormalizeType)) - { - var memoryStream = new MemoryStream(); - image.SaveAsPng(memoryStream); - return memoryStream; - } - } - - - /// - /// Converts to image memory stream. - /// - /// The image tensor. - /// - public static async Task ToImageStreamAsync(this DenseTensor imageTensor, ImageNormalizeType imageNormalizeType = ImageNormalizeType.OneToOne) - { - using (var image = imageTensor.ToImage(imageNormalizeType)) - { - var memoryStream = new MemoryStream(); - await image.SaveAsPngAsync(memoryStream); - return memoryStream; - } - } - /// /// Converts to image mask. /// /// The image tensor. /// - public static Image ToImageMask(this DenseTensor imageTensor) + public static OnnxImage ToImageMask(this DenseTensor imageTensor) { var width = imageTensor.Dimensions[3]; var height = imageTensor.Dimensions[2]; @@ -149,381 +25,10 @@ public static Image ToImageMask(this DenseTensor imageTensor) result[x, y] = new L8((byte)(imageTensor[0, 0, y, x] * 255.0f)); } } - return result.CloneAs(); - } - } - - - /// - /// Converts to Image. - /// - /// The ort value. - /// - public static Image ToImage(this OrtValue ortValue, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne) - { - return ortValue.ToDenseTensor().ToImage(normalizeType); - } - - - /// - /// Converts to image. - /// - /// The image tensor. - /// - public static Image ToImage(this DenseTensor imageTensor, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne) - { - var height = imageTensor.Dimensions[2]; - var width = imageTensor.Dimensions[3]; - var result = new Image(width, height); - for (var y = 0; y < height; y++) - { - for (var x = 0; x < width; x++) - { - if (normalizeType == ImageNormalizeType.ZeroToOne) - { - result[x, y] = new Rgba32( - DenormalizeZeroToOneToByte(imageTensor, 0, y, x), - DenormalizeZeroToOneToByte(imageTensor, 1, y, x), - DenormalizeZeroToOneToByte(imageTensor, 2, y, x)); - } - else - { - result[x, y] = new Rgba32( - DenormalizeOneToOneToByte(imageTensor, 0, y, x), - DenormalizeOneToOneToByte(imageTensor, 1, y, x), - DenormalizeOneToOneToByte(imageTensor, 2, y, x)); - } - } - } - return result; - } - - #endregion - - #region Image To Tensor - - - /// - /// Converts to DenseTensor. - /// - /// The image. - /// The dimensions. - /// - public static DenseTensor ToDenseTensor(this Image image, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3) - { - var dimensions = new[] { 1, channels, image.Height, image.Width }; - return normalizeType == ImageNormalizeType.ZeroToOne - ? NormalizeToZeroToOne(image, dimensions) - : NormalizeToOneToOne(image, dimensions); - } - - - /// - /// Converts to InputImage to DenseTensor. - /// - /// The image data. - /// Type of the image normalize. 
- /// - public static async Task> ToDenseTensorAsync(this InputImage imageData, ImageNormalizeType imageNormalizeType = ImageNormalizeType.OneToOne) - { - if (!string.IsNullOrEmpty(imageData.ImageBase64)) - return await TensorFromBase64Async(imageData.ImageBase64, default, default, imageNormalizeType); - if (imageData.ImageBytes != null) - return await TensorFromBytesAsync(imageData.ImageBytes, default, default, imageNormalizeType); - if (imageData.ImageStream != null) - return await TensorFromStreamAsync(imageData.ImageStream, default, default, imageNormalizeType); - if (imageData.ImageTensor != null) - return imageData.ImageTensor.ToDenseTensor(); // Note: Tensor Copy // TODO: Reshape to dimensions - - return await TensorFromImageAsync(imageData.Image, default, default, imageNormalizeType); - } - - - /// - /// Converts to InputImage to DenseTensor. - /// - /// The image data. - /// Height of the resize. - /// Width of the resize. - /// Type of the image normalize. - /// - public static async Task> ToDenseTensorAsync(this InputImage imageData, int resizeHeight, int resizeWidth, ImageNormalizeType imageNormalizeType = ImageNormalizeType.OneToOne) - { - if (!string.IsNullOrEmpty(imageData.ImageBase64)) - return await TensorFromBase64Async(imageData.ImageBase64, resizeHeight, resizeWidth, imageNormalizeType); - if (imageData.ImageBytes != null) - return await TensorFromBytesAsync(imageData.ImageBytes, resizeHeight, resizeWidth, imageNormalizeType); - if (imageData.ImageStream != null) - return await TensorFromStreamAsync(imageData.ImageStream, resizeHeight, resizeWidth, imageNormalizeType); - if (imageData.ImageTensor != null) - return imageData.ImageTensor.ToDenseTensor(); // Note: Tensor Copy // TODO: Reshape to dimensions - - return await TensorFromImageAsync(imageData.Image, resizeHeight, resizeWidth, imageNormalizeType); - } - - - /// - /// Tensor from image. - /// - /// The image. - /// The height. - /// The width. - /// Type of the image normalize. - /// - private static DenseTensor TensorFromImage(Image image, int height, int width, ImageNormalizeType imageNormalizeType) - { - if (height > 0 && width > 0) - image.Resize(height, width); - - return image.ToDenseTensor(imageNormalizeType); - } - - - /// - /// Tensor from image. - /// - /// The image. - /// The height. - /// The width. - /// Type of the image normalize. - /// - private static Task> TensorFromImageAsync(Image image, int height, int width, ImageNormalizeType imageNormalizeType) - { - return Task.Run(() => TensorFromImage(image, height, width, imageNormalizeType)); - } - - - /// - /// Tensor from image file. - /// - /// The filename. - /// The height. - /// The width. - /// Type of the image normalize. - /// - private static DenseTensor TensorFromFile(string filename, int height, int width, ImageNormalizeType imageNormalizeType) - { - using (var image = ImageSharp.Load(filename)) - { - if (height > 0 && width > 0) - image.Resize(height, width); - - return image.ToDenseTensor(imageNormalizeType); - } - } - - - /// - /// Tensor from image file. - /// - /// The filename. - /// The height. - /// The width. - /// Type of the image normalize. - /// - private static async Task> TensorFromFileAsync(string filename, int height, int width, ImageNormalizeType imageNormalizeType) - { - using (var image = await ImageSharp.LoadAsync(filename)) - { - if (height > 0 && width > 0) - image.Resize(height, width); - - return image.ToDenseTensor(imageNormalizeType); - } - } - - - /// - /// Tensor from base64 image. - /// - /// The base64 image. 
- /// The height. - /// The width. - /// Type of the image normalize. - /// - private static DenseTensor TensorFromBase64(string base64Image, int height, int width, ImageNormalizeType imageNormalizeType) - { - return TensorFromBytes(Convert.FromBase64String(base64Image.Split(',')[1]), height, width, imageNormalizeType); - } - - - /// - /// Tensor from base64 image. - /// - /// The base64 image. - /// The height. - /// The width. - /// Type of the image normalize. - /// - private static async Task> TensorFromBase64Async(string base64Image, int height, int width, ImageNormalizeType imageNormalizeType) - { - return await TensorFromBytesAsync(Convert.FromBase64String(base64Image.Split(',')[1]), height, width, imageNormalizeType); - } - - - /// - /// Tensor from image bytes. - /// - /// The image bytes. - /// The height. - /// The width. - /// Type of the image normalize. - /// - private static DenseTensor TensorFromBytes(byte[] imageBytes, int height, int width, ImageNormalizeType imageNormalizeType) - { - using (var image = ImageSharp.Load(imageBytes)) - { - if (height > 0 && width > 0) - image.Resize(height, width); - - return image.ToDenseTensor(imageNormalizeType); - } - } - - - /// - /// Tensors from image bytes. - /// - /// The image bytes. - /// The height. - /// The width. - /// Type of the image normalize. - /// - private static async Task> TensorFromBytesAsync(byte[] imageBytes, int height, int width, ImageNormalizeType imageNormalizeType) - { - return await Task.Run(() => TensorFromBytes(imageBytes, height, width, imageNormalizeType)); - } - - - /// - /// Tensor from image stream. - /// - /// The image stream. - /// The height. - /// The width. - /// Type of the image normalize. - /// - private static DenseTensor TensorFromStream(Stream imageStream, int height, int width, ImageNormalizeType imageNormalizeType) - { - using (var image = ImageSharp.Load(imageStream)) - { - if (height > 0 && width > 0) - image.Resize(height, width); - - return image.ToDenseTensor(imageNormalizeType); - } - } - - - /// - /// Tensor from image stream. - /// - /// The image stream. - /// The height. - /// The width. - /// Type of the image normalize. - /// - private static async Task> TensorFromStreamAsync(Stream imageStream, int height, int width, ImageNormalizeType imageNormalizeType) - { - using (var image = await ImageSharp.LoadAsync(imageStream)) - { - if (height > 0 && width > 0) - image.Resize(height, width); - - return image.ToDenseTensor(imageNormalizeType); + return new OnnxImage(result.CloneAs()); } } - #endregion - - #region Normalize - - /// - /// Normalizes the pixels from 0-255 to 0-1 - /// - /// The image. - /// The dimensions. - /// - private static DenseTensor NormalizeToZeroToOne(Image image, ReadOnlySpan dimensions) - { - var width = dimensions[3]; - var height = dimensions[2]; - var channels = dimensions[1]; - var imageArray = new DenseTensor(new[] { 1, channels, height, width }); - image.ProcessPixelRows(img => - { - for (int x = 0; x < width; x++) - { - for (int y = 0; y < height; y++) - { - var pixelSpan = img.GetRowSpan(y); - imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f); - imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f); - imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f); - } - } - }); - return imageArray; - } - - - /// - /// Normalizes the pixels from 0-255 to 0-1 - /// - /// The image. - /// The dimensions. 
- /// - private static DenseTensor NormalizeToOneToOne(Image image, ReadOnlySpan dimensions) - { - var width = dimensions[3]; - var height = dimensions[2]; - var channels = dimensions[1]; - var imageArray = new DenseTensor(new[] { 1, channels, height, width }); - image.ProcessPixelRows(img => - { - for (int x = 0; x < width; x++) - { - for (int y = 0; y < height; y++) - { - var pixelSpan = img.GetRowSpan(y); - imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f) * 2.0f - 1.0f; - imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f) * 2.0f - 1.0f; - imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f) * 2.0f - 1.0f; - } - } - }); - return imageArray; - } - - - /// - /// Denormalizes the pixels from 0 to 1 to 0-255 - /// - /// The image tensor. - /// The index. - /// The y. - /// The x. - /// - private static byte DenormalizeZeroToOneToByte(DenseTensor imageTensor, int index, int y, int x) - { - return (byte)Math.Round(Math.Clamp(imageTensor[0, index, y, x], 0, 1) * 255); - } - - - /// - /// Denormalizes the pixels from -1 to 1 to 0-255 - /// - /// The image tensor. - /// The index. - /// The y. - /// The x. - /// - private static byte DenormalizeOneToOneToByte(Tensor imageTensor, int index, int y, int x) - { - return (byte)Math.Round(Math.Clamp(imageTensor[0, index, y, x] / 2 + 0.5, 0, 1) * 255); - } - - #endregion } public enum ImageNormalizeType diff --git a/OnnxStack.Core/Image/InputImage.cs b/OnnxStack.Core/Image/InputImage.cs deleted file mode 100644 index f27dcedd..00000000 --- a/OnnxStack.Core/Image/InputImage.cs +++ /dev/null @@ -1,105 +0,0 @@ -using Microsoft.ML.OnnxRuntime.Tensors; -using SixLabors.ImageSharp; -using SixLabors.ImageSharp.PixelFormats; -using System.IO; -using System.Text.Json.Serialization; -using System.Threading.Tasks; - -namespace OnnxStack.Core.Image -{ - public class InputImage - { - /// - /// Initializes a new instance of the class. - /// - public InputImage() { } - - /// - /// Initializes a new instance of the class. - /// - /// The image. - public InputImage(Image image) => Image = image; - - /// - /// Initializes a new instance of the class. - /// - /// The image in base64 format. - public InputImage(string imageBase64) => ImageBase64 = imageBase64; - - /// - /// Initializes a new instance of the class. - /// - /// The image bytes. - public InputImage(byte[] imageBytes) => ImageBytes = imageBytes; - - /// - /// Initializes a new instance of the class. - /// - /// The image stream. - public InputImage(Stream imageStream) => ImageStream = imageStream; - - /// - /// Initializes a new instance of the class. - /// - /// The image tensor. - public InputImage(DenseTensor imageTensor) => ImageTensor = imageTensor; - - /// - /// Gets the image. - /// - [JsonIgnore] - public Image Image { get; set; } - - - /// - /// Gets the image base64 string. - /// - public string ImageBase64 { get; set; } - - - /// - /// Gets the image bytes. - /// - [JsonIgnore] - public byte[] ImageBytes { get; set; } - - - /// - /// Gets the image stream. - /// - [JsonIgnore] - public Stream ImageStream { get; set; } - - - /// - /// Gets the image tensor. - /// - [JsonIgnore] - public DenseTensor ImageTensor { get; set; } - - - /// - /// Gets a value indicating whether this instance has image. - /// - /// - /// true if this instance has image; otherwise, false. 
- /// - [JsonIgnore] - public bool HasImage => Image != null - || !string.IsNullOrEmpty(ImageBase64) - || ImageBytes != null - || ImageStream != null - || ImageTensor != null; - - - /// - /// Create an image from file - /// - /// The file path. - /// - public static async Task FromFileAsync(string filePath) - { - return new InputImage(await File.ReadAllBytesAsync(filePath)); - } - } -} diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs new file mode 100644 index 00000000..45f5cd3a --- /dev/null +++ b/OnnxStack.Core/Image/OnnxImage.cs @@ -0,0 +1,395 @@ +using Microsoft.ML.OnnxRuntime.Tensors; +using SixLabors.ImageSharp; +using SixLabors.ImageSharp.Formats.Png; +using SixLabors.ImageSharp.PixelFormats; +using SixLabors.ImageSharp.Processing; +using System; +using System.IO; +using System.Threading.Tasks; +using ImageSharp = SixLabors.ImageSharp.Image; + +namespace OnnxStack.Core.Image +{ + public class OnnxImage : IDisposable + { + private readonly Image _imageData; + + + /// + /// Initializes a new instance of the class. + /// + /// The image. + public OnnxImage(Image image) + { + _imageData = image.Clone(); + } + + + /// + /// Initializes a new instance of the class. + /// + /// The filename. + public OnnxImage(string filename) + { + _imageData = ImageSharp.Load(filename); + } + + + /// + /// Initializes a new instance of the class. + /// + /// The image bytes. + public OnnxImage(byte[] imageBytes) + { + _imageData = ImageSharp.Load(imageBytes); + } + + + /// + /// Initializes a new instance of the class. + /// + /// The image stream. + public OnnxImage(Stream imageStream) + { + _imageData = ImageSharp.Load(imageStream); + } + + + /// + /// Initializes a new instance of the class. + /// + /// The image tensor. + /// Type of the normalize. + public OnnxImage(DenseTensor imageTensor, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne) + { + var height = imageTensor.Dimensions[2]; + var width = imageTensor.Dimensions[3]; + _imageData = new Image(width, height); + for (var y = 0; y < height; y++) + { + for (var x = 0; x < width; x++) + { + if (normalizeType == ImageNormalizeType.ZeroToOne) + { + _imageData[x, y] = new Rgba32( + DenormalizeZeroToOneToByte(imageTensor, 0, y, x), + DenormalizeZeroToOneToByte(imageTensor, 1, y, x), + DenormalizeZeroToOneToByte(imageTensor, 2, y, x)); + } + else + { + _imageData[x, y] = new Rgba32( + DenormalizeOneToOneToByte(imageTensor, 0, y, x), + DenormalizeOneToOneToByte(imageTensor, 1, y, x), + DenormalizeOneToOneToByte(imageTensor, 2, y, x)); + } + } + } + } + + + /// + /// Gets the height. + /// + public int Height => _imageData.Height; + + /// + /// Gets the width. + /// + public int Width => _imageData.Width; + + /// + /// Gets a value indicating whether this instance has image. + /// + /// + /// true if this instance has image; otherwise, false. + /// + public bool HasImage + { + get { return _imageData != null; } + } + + + /// + /// Gets the image. + /// + /// + public Image GetImage() + { + return _imageData; + } + + + /// + /// Gets the image as base64. + /// + /// + public string GetImageBase64() + { + return _imageData?.ToBase64String(PngFormat.Instance); + } + + + /// + /// Gets the image as bytes. + /// + /// + public byte[] GetImageBytes() + { + using (var memoryStream = new MemoryStream()) + { + _imageData.SaveAsPng(memoryStream); + return memoryStream.ToArray(); + } + } + + + /// + /// Gets the image as stream. 
+ /// + /// + public Stream GetImageStream() + { + var memoryStream = new MemoryStream(); + _imageData.SaveAsPng(memoryStream); + return memoryStream; + } + + + /// + /// Gets the image as tensor. + /// + /// Type of the normalize. + /// The channels. + /// + public DenseTensor GetImageTensor(ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3) + { + var dimensions = new[] { 1, channels, Height, Width }; + return normalizeType == ImageNormalizeType.ZeroToOne + ? NormalizeToZeroToOne(dimensions) + : NormalizeToOneToOne(dimensions); + } + + + /// + /// Gets the image as tensor. + /// + /// The height. + /// The width. + /// Type of the normalize. + /// The channels. + /// + public DenseTensor GetImageTensor(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3) + { + if (height > 0 && width > 0) + Resize(height, width); + + return GetImageTensor(normalizeType, channels); + } + + + /// + /// Gets the image as tensor asynchronously. + /// + /// The height. + /// The width. + /// Type of the normalize. + /// The channels. + /// + public Task> GetImageTensorAsync(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3) + { + return Task.Run(() => GetImageTensor(height, width, normalizeType, channels)); + } + + + /// + /// Resizes the image. + /// + /// The height. + /// The width. + /// The resize mode. + public void Resize(int height, int width, ResizeMode resizeMode = ResizeMode.Crop) + { + _imageData.Mutate(x => + { + x.Resize(new ResizeOptions + { + Size = new Size(width, height), + Mode = resizeMode, + Sampler = KnownResamplers.Lanczos8, + Compand = true + }); + }); + } + + + /// + /// Saves the specified image to file. + /// + /// The filename. + public void Save(string filename) + { + _imageData?.SaveAsPng(filename); + } + + + /// + /// Saves the specified image to file asynchronously. + /// + /// The filename. + /// + public Task SaveAsync(string filename) + { + return _imageData?.SaveAsPngAsync(filename); + } + + + /// + /// Saves the specified image to stream. + /// + /// The stream. + public void Save(Stream stream) + { + _imageData?.SaveAsPng(stream); + } + + + /// + /// Saves the specified image to stream asynchronously. + /// + /// The stream. + /// + public Task SaveAsync(Stream stream) + { + return _imageData?.SaveAsPngAsync(stream); + } + + + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// + public void Dispose() + { + _imageData?.Dispose(); + } + + + /// + /// Normalizes the pixels from 0-255 to 0-1 + /// + /// The image. + /// The dimensions. + /// + private DenseTensor NormalizeToZeroToOne(ReadOnlySpan dimensions) + { + var width = dimensions[3]; + var height = dimensions[2]; + var channels = dimensions[1]; + var imageArray = new DenseTensor(new[] { 1, channels, height, width }); + _imageData.ProcessPixelRows(img => + { + for (int x = 0; x < width; x++) + { + for (int y = 0; y < height; y++) + { + var pixelSpan = img.GetRowSpan(y); + imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f); + imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f); + imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f); + } + } + }); + return imageArray; + } + + + /// + /// Normalizes the pixels from 0-255 to 0-1 + /// + /// The image. + /// The dimensions. 
+ /// + private DenseTensor NormalizeToOneToOne(ReadOnlySpan dimensions) + { + var width = dimensions[3]; + var height = dimensions[2]; + var channels = dimensions[1]; + var imageArray = new DenseTensor(new[] { 1, channels, height, width }); + _imageData.ProcessPixelRows(img => + { + for (int x = 0; x < width; x++) + { + for (int y = 0; y < height; y++) + { + var pixelSpan = img.GetRowSpan(y); + imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f) * 2.0f - 1.0f; + imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f) * 2.0f - 1.0f; + imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f) * 2.0f - 1.0f; + } + } + }); + return imageArray; + } + + + /// + /// Denormalizes the pixels from 0 to 1 to 0-255 + /// + /// The image tensor. + /// The index. + /// The y. + /// The x. + /// + private static byte DenormalizeZeroToOneToByte(DenseTensor imageTensor, int index, int y, int x) + { + return (byte)Math.Round(Math.Clamp(imageTensor[0, index, y, x], 0, 1) * 255); + } + + + /// + /// Denormalizes the pixels from -1 to 1 to 0-255 + /// + /// The image tensor. + /// The index. + /// The y. + /// The x. + /// + private static byte DenormalizeOneToOneToByte(Tensor imageTensor, int index, int y, int x) + { + return (byte)Math.Round(Math.Clamp(imageTensor[0, index, y, x] / 2 + 0.5, 0, 1) * 255); + } + + + /// + /// Create OnnxImage from file asynchronously + /// + /// The file path. + /// + public static async Task FromFileAsync(string filePath) + { + return new OnnxImage(await ImageSharp.LoadAsync(filePath)); + } + + + /// + /// Create OnnxImage from stream asynchronously + /// + /// The image stream. + /// + public static async Task FromStreamAsync(Stream imageStream) + { + return new OnnxImage(await ImageSharp.LoadAsync(imageStream)); + } + + + /// + /// Create OnnxImage from bytes asynchronously + /// + /// The image stream. 
+ /// + public static async Task FromBytesAsync(byte[] imageBytes) + { + return await Task.Run(() => new OnnxImage(imageBytes)); + } + } +} diff --git a/OnnxStack.Core/Video/Extensions.cs b/OnnxStack.Core/Video/Extensions.cs index d32a2c96..dbabfa88 100644 --- a/OnnxStack.Core/Video/Extensions.cs +++ b/OnnxStack.Core/Video/Extensions.cs @@ -26,7 +26,7 @@ public static IEnumerable ToVideoFramesAsBytes(this DenseTensor v { foreach (var frame in videoTensor.ToVideoFrames()) { - yield return frame.ToImageBytes(); + yield return new OnnxImage(frame).GetImageBytes(); } } @@ -34,16 +34,16 @@ public static async IAsyncEnumerable ToVideoFramesAsBytesAsync(this Dens { foreach (var frame in videoTensor.ToVideoFrames()) { - yield return await frame.ToImageBytesAsync(); + yield return new OnnxImage(frame).GetImageBytes(); } } - public static IEnumerable> ToVideoFramesAsImage(this DenseTensor videoTensor) - { - foreach (var frame in videoTensor.ToVideoFrames()) - { - yield return frame.ToImage(); - } - } + //public static IEnumerable> ToVideoFramesAsImage(this DenseTensor videoTensor) + //{ + // foreach (var frame in videoTensor.ToVideoFrames()) + // { + // yield return frame.ToImage(); + // } + //} } } diff --git a/OnnxStack.Core/Video/VideoFrame.cs b/OnnxStack.Core/Video/VideoFrame.cs index acbe0f05..835f8d10 100644 --- a/OnnxStack.Core/Video/VideoFrame.cs +++ b/OnnxStack.Core/Video/VideoFrame.cs @@ -4,6 +4,6 @@ namespace OnnxStack.Core.Video { public record VideoFrame(byte[] Frame) { - public InputImage ExtraFrame { get; set; } + public OnnxImage ExtraFrame { get; set; } } } diff --git a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs index 09075181..1d80ef8e 100644 --- a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs +++ b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs @@ -54,10 +54,10 @@ public async Task UnloadAsync() /// /// The input image. 
/// - public async Task RunAsync(InputImage inputImage, CancellationToken cancellationToken = default) + public async Task RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { var timestamp = _logger?.LogBegin("Extracting image feature..."); - var controlImage = await inputImage.ToDenseTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne); + var controlImage = await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne); var metadata = await _featureExtractorModel.GetMetadataAsync(); cancellationToken.ThrowIfCancellationRequested(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) @@ -77,7 +77,7 @@ public async Task RunAsync(InputImage inputImage, CancellationToken var maskImage = resultTensor.ToImageMask(); //await maskImage.SaveAsPngAsync("D:\\Mask.png"); _logger?.LogEnd("Extracting image feature complete.", timestamp); - return new InputImage(maskImage); + return maskImage; } } } @@ -96,8 +96,8 @@ public async Task RunAsync(VideoFrames videoFrames, CancellationTok foreach (var videoFrame in videoFrames.Frames) { - var image = new InputImage(videoFrame.Frame); - var controlImage = await image.ToDenseTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne); + var image = new OnnxImage(videoFrame.Frame); + var controlImage = await image.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) { inferenceParameters.AddInputTensor(controlImage); @@ -113,7 +113,7 @@ public async Task RunAsync(VideoFrames videoFrames, CancellationTok resultTensor.NormalizeMinMax(); var maskImage = resultTensor.ToImageMask(); - videoFrame.ExtraFrame = new InputImage(maskImage); + videoFrame.ExtraFrame = maskImage; } } } diff --git a/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs b/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs index f0ad1fdc..897a905f 100644 --- a/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs +++ b/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs @@ -1,4 +1,5 @@ using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core.Image; using OnnxStack.ImageUpscaler.Models; using SixLabors.ImageSharp; using SixLabors.ImageSharp.PixelFormats; @@ -17,7 +18,7 @@ internal static class ImageExtensions /// Maximum size of the tile. /// The scale factor. /// - public static List GenerateTiles(this Image imageSource, int sampleSize, int scaleFactor) + public static List GenerateTiles(this OnnxImage imageSource, int sampleSize, int scaleFactor) { var tiles = new List(); var tileSizeX = Math.Min(sampleSize, imageSource.Width); @@ -47,11 +48,11 @@ public static List GenerateTiles(this Image imageSource, int /// The source image. /// The source area. 
/// - public static Image ExtractTile(this Image sourceImage, Rectangle sourceArea) + public static OnnxImage ExtractTile(this OnnxImage sourceImage, Rectangle sourceArea) { var height = sourceArea.Height; var targetImage = new Image(sourceArea.Width, sourceArea.Height); - sourceImage.ProcessPixelRows(targetImage, (sourceAccessor, targetAccessor) => + sourceImage.GetImage().ProcessPixelRows(targetImage, (sourceAccessor, targetAccessor) => { for (int i = 0; i < height; i++) { @@ -60,7 +61,7 @@ public static Image ExtractTile(this Image sourceImage, Rectangl sourceRow.Slice(sourceArea.X, sourceArea.Width).CopyTo(targetRow); } }); - return targetImage; + return new OnnxImage(targetImage); } diff --git a/OnnxStack.ImageUpscaler/Models/ImageTile.cs b/OnnxStack.ImageUpscaler/Models/ImageTile.cs index 39a402aa..12502aab 100644 --- a/OnnxStack.ImageUpscaler/Models/ImageTile.cs +++ b/OnnxStack.ImageUpscaler/Models/ImageTile.cs @@ -1,12 +1,11 @@ -using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core.Image; using SixLabors.ImageSharp; -using SixLabors.ImageSharp.PixelFormats; namespace OnnxStack.ImageUpscaler.Models { public record ImageTile { - public Image Image { get; set; } + public OnnxImage Image { get; set; } public Rectangle Destination { get; set; } } } diff --git a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs index f0f5fc3a..964d259f 100644 --- a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs +++ b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs @@ -29,7 +29,7 @@ public class ImageUpscalePipeline /// The name. /// The upscale model. /// The logger. - public ImageUpscalePipeline(string name, UpscaleModel upscaleModel, ILogger logger = default) + public ImageUpscalePipeline(string name, UpscaleModel upscaleModel, ILogger logger = default) { _name = name; _logger = logger; @@ -69,39 +69,35 @@ public async Task UnloadAsync() /// The input image. /// The cancellation token. 
/// - public async Task> RunAsync(InputImage inputImage, CancellationToken cancellationToken = default) + public async Task> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { - using (var image = await inputImage.ToImageAsync()) + var upscaleInput = CreateInputParams(inputImage, _upscaleModel.SampleSize, _upscaleModel.ScaleFactor); + var metadata = await _upscaleModel.GetMetadataAsync(); + + var outputTensor = new DenseTensor(new[] { 1, _upscaleModel.Channels, upscaleInput.OutputHeight, upscaleInput.OutputWidth }); + foreach (var imageTile in upscaleInput.ImageTiles) { - var upscaleInput = CreateInputParams(image, _upscaleModel.SampleSize, _upscaleModel.ScaleFactor); - var metadata = await _upscaleModel.GetMetadataAsync(); + cancellationToken.ThrowIfCancellationRequested(); - var outputTensor = new DenseTensor(new[] { 1, _upscaleModel.Channels, upscaleInput.OutputHeight, upscaleInput.OutputWidth }); - foreach (var imageTile in upscaleInput.ImageTiles) + var outputDimension = new[] { 1, _upscaleModel.Channels, imageTile.Destination.Height, imageTile.Destination.Width }; + var inputTensor = imageTile.Image.GetImageTensor(ImageNormalizeType.ZeroToOne, _upscaleModel.Channels); + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) { - cancellationToken.ThrowIfCancellationRequested(); + inferenceParameters.AddInputTensor(inputTensor); + inferenceParameters.AddOutputBuffer(outputDimension); - var outputDimension = new[] { 1, _upscaleModel.Channels, imageTile.Destination.Height, imageTile.Destination.Width }; - var inputTensor = imageTile.Image.ToDenseTensor(ImageNormalizeType.ZeroToOne, _upscaleModel.Channels); - using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + var results = await _upscaleModel.RunInferenceAsync(inferenceParameters); + using (var result = results.First()) { - inferenceParameters.AddInputTensor(inputTensor); - inferenceParameters.AddOutputBuffer(outputDimension); - - var results = await _upscaleModel.RunInferenceAsync(inferenceParameters); - using (var result = results.First()) - { - outputTensor.ApplyImageTile(result.ToDenseTensor(), imageTile.Destination); - } + outputTensor.ApplyImageTile(result.ToDenseTensor(), imageTile.Destination); } } - - return outputTensor; } + return outputTensor; } - private static UpscaleInput CreateInputParams(Image imageSource, int maxTileSize, int scaleFactor) + private static UpscaleInput CreateInputParams(OnnxImage imageSource, int maxTileSize, int scaleFactor) { var tiles = imageSource.GenerateTiles(maxTileSize, scaleFactor); var width = imageSource.Width * scaleFactor; diff --git a/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs b/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs index 5d090884..ba71b7e9 100644 --- a/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs +++ b/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs @@ -42,7 +42,7 @@ public interface IUpscaleService /// The model options. /// The input image. /// - Task> GenerateAsync(UpscaleModelSet modelOptions, InputImage inputImage, CancellationToken cancellationToken = default); + Task> GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); /// /// Generates the upscaled image. @@ -50,7 +50,7 @@ public interface IUpscaleService /// The model options. /// The input image. 
/// - Task> GenerateAsImageAsync(UpscaleModelSet modelOptions, InputImage inputImage, CancellationToken cancellationToken = default); + Task GenerateAsImageAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); /// @@ -59,7 +59,7 @@ public interface IUpscaleService /// The model options. /// The input image. /// - Task GenerateAsByteAsync(UpscaleModelSet modelOptions, InputImage inputImage, CancellationToken cancellationToken = default); + Task GenerateAsByteAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); /// @@ -68,7 +68,7 @@ public interface IUpscaleService /// The model options. /// The input image. /// - Task GenerateAsStreamAsync(UpscaleModelSet modelOptions, InputImage inputImage, CancellationToken cancellationToken = default); + Task GenerateAsStreamAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); /// diff --git a/OnnxStack.ImageUpscaler/Services/UpscaleService.cs b/OnnxStack.ImageUpscaler/Services/UpscaleService.cs index a46a4880..8b0c3e37 100644 --- a/OnnxStack.ImageUpscaler/Services/UpscaleService.cs +++ b/OnnxStack.ImageUpscaler/Services/UpscaleService.cs @@ -6,7 +6,6 @@ using OnnxStack.Core.Services; using OnnxStack.Core.Video; using OnnxStack.ImageUpscaler.Common; -using OnnxStack.ImageUpscaler.Config; using OnnxStack.ImageUpscaler.Extensions; using OnnxStack.ImageUpscaler.Models; using SixLabors.ImageSharp; @@ -86,7 +85,7 @@ public bool IsModelLoaded(UpscaleModelSet modelOptions) /// The model options. /// The input image. /// - public async Task> GenerateAsync(UpscaleModelSet modelOptions, InputImage inputImage, CancellationToken cancellationToken = default) + public async Task> GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default) { return await GenerateInternalAsync(modelOptions, inputImage, cancellationToken); } @@ -98,10 +97,10 @@ public async Task> GenerateAsync(UpscaleModelSet modelOptions /// The model options. /// The input image. /// - public async Task> GenerateAsImageAsync(UpscaleModelSet modelOptions, InputImage inputImage, CancellationToken cancellationToken = default) + public async Task GenerateAsImageAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default) { var imageTensor = await GenerateInternalAsync(modelOptions, inputImage, cancellationToken); - return imageTensor.ToImage(ImageNormalizeType.ZeroToOne); + return new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne); } @@ -111,13 +110,13 @@ public async Task> GenerateAsImageAsync(UpscaleModelSet modelOptio /// The model options. /// The input image. 
/// - public async Task GenerateAsByteAsync(UpscaleModelSet modelOptions, InputImage inputImage, CancellationToken cancellationToken = default) + public async Task GenerateAsByteAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default) { var imageTensor = await GenerateInternalAsync(modelOptions, inputImage, cancellationToken); using (var memoryStream = new MemoryStream()) - using (var image = imageTensor.ToImage(ImageNormalizeType.ZeroToOne)) + using (var image = new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne)) { - await image.SaveAsPngAsync(memoryStream); + await image.SaveAsync(memoryStream); return memoryStream.ToArray(); } } @@ -129,13 +128,13 @@ public async Task GenerateAsByteAsync(UpscaleModelSet modelOptions, Inpu /// The model options. /// The input image. /// - public async Task GenerateAsStreamAsync(UpscaleModelSet modelOptions, InputImage inputImage, CancellationToken cancellationToken = default) + public async Task GenerateAsStreamAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default) { var imageTensor = await GenerateInternalAsync(modelOptions, inputImage, cancellationToken); - using (var image = imageTensor.ToImage(ImageNormalizeType.ZeroToOne)) + using (var image = new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne)) { var memoryStream = new MemoryStream(); - await image.SaveAsPngAsync(memoryStream); + await image.SaveAsync(memoryStream); return memoryStream; } } @@ -180,9 +179,9 @@ public async Task GenerateAsByteAsync(UpscaleModelSet modelOptions, Vide cancellationToken.ThrowIfCancellationRequested(); using (var imageStream = new MemoryStream()) - using (var imageFrame = tensorFrame.ToImage(ImageNormalizeType.ZeroToOne)) + using (var imageFrame = new OnnxImage(tensorFrame, ImageNormalizeType.ZeroToOne)) { - await imageFrame.SaveAsPngAsync(imageStream); + await imageFrame.SaveAsync(imageStream); return imageStream.ToArray(); } })); @@ -211,39 +210,36 @@ public async Task GenerateAsStreamAsync(UpscaleModelSet modelOptions, Vi /// /// The model options. /// The input image. 
- private async Task> GenerateInternalAsync(UpscaleModelSet modelSet, InputImage inputImage, CancellationToken cancellationToken) + private async Task> GenerateInternalAsync(UpscaleModelSet modelSet, OnnxImage inputImage, CancellationToken cancellationToken) { if (!_modelSessions.TryGetValue(modelSet, out var modelSession)) throw new System.Exception("Model not loaded"); - using (var image = await inputImage.ToImageAsync()) - { + var upscaleInput = CreateInputParams(inputImage, modelSession.SampleSize, modelSession.ScaleFactor); + var metadata = await modelSession.GetMetadataAsync(); - var upscaleInput = CreateInputParams(image, modelSession.SampleSize, modelSession.ScaleFactor); - var metadata = await modelSession.GetMetadataAsync(); + var outputTensor = new DenseTensor(new[] { 1, modelSession.Channels, upscaleInput.OutputHeight, upscaleInput.OutputWidth }); + foreach (var imageTile in upscaleInput.ImageTiles) + { + cancellationToken.ThrowIfCancellationRequested(); - var outputTensor = new DenseTensor(new[] { 1, modelSession.Channels, upscaleInput.OutputHeight, upscaleInput.OutputWidth }); - foreach (var imageTile in upscaleInput.ImageTiles) + var outputDimension = new[] { 1, modelSession.Channels, imageTile.Destination.Height, imageTile.Destination.Width }; + var inputTensor = imageTile.Image.GetImageTensor(ImageNormalizeType.ZeroToOne, modelSession.Channels); + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) { - cancellationToken.ThrowIfCancellationRequested(); + inferenceParameters.AddInputTensor(inputTensor); + inferenceParameters.AddOutputBuffer(outputDimension); - var outputDimension = new[] { 1, modelSession.Channels, imageTile.Destination.Height, imageTile.Destination.Width }; - var inputTensor = imageTile.Image.ToDenseTensor(ImageNormalizeType.ZeroToOne, modelSession.Channels); - using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + var results = await modelSession.RunInferenceAsync(inferenceParameters); + using (var result = results.First()) { - inferenceParameters.AddInputTensor(inputTensor); - inferenceParameters.AddOutputBuffer(outputDimension); - - var results = await modelSession.RunInferenceAsync(inferenceParameters); - using (var result = results.First()) - { - outputTensor.ApplyImageTile(result.ToDenseTensor(), imageTile.Destination); - } + outputTensor.ApplyImageTile(result.ToDenseTensor(), imageTile.Destination); } } - - return outputTensor; } + + return outputTensor; + } @@ -265,14 +261,14 @@ public async Task>> GenerateInternalAsync(UpscaleModelSe var outputTensors = new List>(); foreach (var frame in videoFrames.Frames) { - using (var imageFrame = Image.Load(frame.Frame)) + using (var imageFrame = new OnnxImage(frame.Frame)) { var input = CreateInputParams(imageFrame, modelSession.SampleSize, modelSession.ScaleFactor); var outputDimension = new[] { 1, modelSession.Channels, 0, 0 }; var outputTensor = new DenseTensor(new[] { 1, modelSession.Channels, input.OutputHeight, input.OutputWidth }); foreach (var imageTile in input.ImageTiles) { - var inputTensor = imageTile.Image.ToDenseTensor(ImageNormalizeType.ZeroToOne, modelSession.Channels); + var inputTensor = imageTile.Image.GetImageTensor(ImageNormalizeType.ZeroToOne, modelSession.Channels); outputDimension[2] = imageTile.Destination.Height; outputDimension[3] = imageTile.Destination.Width; using (var inferenceParameters = new OnnxInferenceParameters(metadata)) @@ -301,7 +297,7 @@ public async Task>> GenerateInternalAsync(UpscaleModelSe /// Maximum size of the tile. 
/// The scale factor. /// - private static UpscaleInput CreateInputParams(Image imageSource, int maxTileSize, int scaleFactor) + private static UpscaleInput CreateInputParams(OnnxImage imageSource, int maxTileSize, int scaleFactor) { var tiles = imageSource.GenerateTiles(maxTileSize, scaleFactor); var width = imageSource.Width * scaleFactor; diff --git a/OnnxStack.StableDiffusion/Common/BatchResult.cs b/OnnxStack.StableDiffusion/Common/BatchResult.cs index ef4162e5..b682ac2b 100644 --- a/OnnxStack.StableDiffusion/Common/BatchResult.cs +++ b/OnnxStack.StableDiffusion/Common/BatchResult.cs @@ -1,7 +1,7 @@ -using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core.Image; using OnnxStack.StableDiffusion.Config; namespace OnnxStack.StableDiffusion.Common { - public record BatchResult(SchedulerOptions SchedulerOptions, DenseTensor ImageResult); + public record BatchResult(SchedulerOptions SchedulerOptions, OnnxImage ImageResult); } diff --git a/OnnxStack.StableDiffusion/Config/PromptOptions.cs b/OnnxStack.StableDiffusion/Config/PromptOptions.cs index 15f45961..49f368dc 100644 --- a/OnnxStack.StableDiffusion/Config/PromptOptions.cs +++ b/OnnxStack.StableDiffusion/Config/PromptOptions.cs @@ -16,11 +16,11 @@ public class PromptOptions [StringLength(512)] public string NegativePrompt { get; set; } - public InputImage InputImage { get; set; } + public OnnxImage InputImage { get; set; } - public InputImage InputImageMask { get; set; } + public OnnxImage InputImageMask { get; set; } - public InputImage InputContolImage { get; set; } + public OnnxImage InputContolImage { get; set; } public VideoInput InputVideo { get; set; } diff --git a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs index dfcf4d68..2c8b9b57 100644 --- a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs @@ -202,7 +202,7 @@ protected static DenseTensor CreateConditioningScaleTensor(float conditi /// protected async Task> PrepareControlImage(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - return await promptOptions.InputContolImage.ToDenseTensorAsync(schedulerOptions.Height, schedulerOptions.Width, ImageNormalizeType.ZeroToOne); + return await promptOptions.InputContolImage.GetImageTensorAsync(schedulerOptions.Height, schedulerOptions.Width, ImageNormalizeType.ZeroToOne); } } } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs index dc4e4fd4..3c0a53d7 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs @@ -200,7 +200,7 @@ protected static DenseTensor CreateConditioningScaleTensor(float conditi /// protected async Task> PrepareControlImage(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - return await promptOptions.InputContolImage.ToDenseTensorAsync(schedulerOptions.Height, schedulerOptions.Width , ImageNormalizeType.ZeroToOne); + return await promptOptions.InputContolImage.GetImageTensorAsync(schedulerOptions.Height, schedulerOptions.Width , ImageNormalizeType.ZeroToOne); } } } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs index b9dc0c70..0a0f8a47 100644 
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs @@ -60,7 +60,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); //TODO: Model Config, Channels var outputDimension = options.GetScaledDimension(); diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs index 211abd88..6b5d931c 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs @@ -59,7 +59,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); var outputDimension = options.GetScaledDimension(); var metadata = await _vaeEncoder.GetMetadataAsync(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs index 6d020866..45c4600e 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs @@ -161,7 +161,7 @@ public override async Task> DiffuseAsync(PromptOptions prompt /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); var outputDimensions = options.GetScaledDimension(); var metadata = await _vaeEncoder.GetMetadataAsync(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) @@ -192,7 +192,7 @@ protected override async Task> PrepareLatentsAsync(PromptOpti /// private DenseTensor PrepareMask(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - using (var mask = promptOptions.InputImageMask.ToImage()) + using (var mask = promptOptions.InputImageMask.GetImage()) { // Prepare the mask int width = schedulerOptions.GetScaledWidth(); diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs index a077de6c..8ddd1b08 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs @@ -204,7 +204,7 @@ protected static DenseTensor CreateConditioningScaleTensor(float conditi /// protected async Task> PrepareControlImage(PromptOptions promptOptions, 
SchedulerOptions schedulerOptions) { - return await promptOptions.InputContolImage.ToDenseTensorAsync(schedulerOptions.Height, schedulerOptions.Width, ImageNormalizeType.ZeroToOne); + return await promptOptions.InputContolImage.GetImageTensorAsync(schedulerOptions.Height, schedulerOptions.Width, ImageNormalizeType.ZeroToOne); } } } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs index 28dde8e2..fa751a12 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs @@ -60,7 +60,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); var outputDimension = options.GetScaledDimension(); var metadata = await _vaeEncoder.GetMetadataAsync(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ImageDiffuser.cs index 42f8110c..068ae2c3 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ImageDiffuser.cs @@ -61,7 +61,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); var outputDimension = options.GetScaledDimension(); var metadata = await _vaeEncoder.GetMetadataAsync(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs index c18b1eac..509922a6 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs @@ -156,7 +156,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); var outputDimensions = options.GetScaledDimension(); var metadata = await _vaeEncoder.GetMetadataAsync(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) @@ -187,7 +187,7 @@ protected override async Task> PrepareLatentsAsync(PromptOpti /// private DenseTensor PrepareMask(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - using (var mask = promptOptions.InputImageMask.ToImage()) + 
using (var mask = promptOptions.InputImageMask.GetImage()) { // Prepare the mask int width = schedulerOptions.GetScaledWidth(); diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs index c39b7728..cef8ecdb 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs @@ -196,7 +196,7 @@ protected static DenseTensor CreateConditioningScaleTensor(float conditi /// protected async Task> PrepareControlImage(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - return await promptOptions.InputContolImage.ToDenseTensorAsync(schedulerOptions.Height, schedulerOptions.Width, ImageNormalizeType.ZeroToOne); + return await promptOptions.InputContolImage.GetImageTensorAsync(schedulerOptions.Height, schedulerOptions.Width, ImageNormalizeType.ZeroToOne); } } } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetImageDiffuser.cs index f6a887f0..e8a1edb4 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetImageDiffuser.cs @@ -60,7 +60,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); var outputDimension = options.GetScaledDimension(); var metadata = await _vaeEncoder.GetMetadataAsync(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs index 7cad1190..b315ae15 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs @@ -60,7 +60,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); var outputDimension = options.GetScaledDimension(); var metadata = await _vaeEncoder.GetMetadataAsync(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs index 36a1eada..c68c38a6 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs @@ -129,7 +129,7 @@ public override async Task> DiffuseAsync(PromptOptions prompt /// private DenseTensor PrepareMask(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - using (var imageMask = promptOptions.InputImageMask.ToImage()) + using (var imageMask = 
promptOptions.InputImageMask.GetImage()) { var width = schedulerOptions.GetScaledWidth(); var height = schedulerOptions.GetScaledHeight(); @@ -173,8 +173,8 @@ private DenseTensor PrepareMask(PromptOptions promptOptions, SchedulerOpt /// private async Task> PrepareImageMask(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - using (var image = await promptOptions.InputImage.ToImageAsync()) - using (var mask = await promptOptions.InputImageMask.ToImageAsync()) + using (var image = promptOptions.InputImage.GetImage()) + using (var mask = promptOptions.InputImageMask.GetImage()) { int width = schedulerOptions.Width; int height = schedulerOptions.Height; diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs index 554894e5..d821493b 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs @@ -150,7 +150,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); var outputDimensions = options.GetScaledDimension(); var metadata = await _vaeEncoder.GetMetadataAsync(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) @@ -181,7 +181,7 @@ protected override async Task> PrepareLatentsAsync(PromptOpti /// private DenseTensor PrepareMask(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - using (var mask = promptOptions.InputImageMask.ToImage()) + using (var mask = promptOptions.InputImageMask.GetImage()) { // Prepare the mask int width = schedulerOptions.GetScaledWidth(); diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs index b803563a..983dbfc2 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs @@ -206,7 +206,7 @@ protected static DenseTensor CreateConditioningScaleTensor(float conditi /// protected async Task> PrepareControlImage(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - return await promptOptions.InputContolImage.ToDenseTensorAsync(schedulerOptions.Height, schedulerOptions.Width, ImageNormalizeType.ZeroToOne); + return await promptOptions.InputContolImage.GetImageTensorAsync(schedulerOptions.Height, schedulerOptions.Width, ImageNormalizeType.ZeroToOne); } } } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs index 9838a4d7..38e8e444 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs @@ -61,7 +61,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await 
prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); //TODO: Model Config, Channels var outputDimension = options.GetScaledDimension(); diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs index fa2e2c0d..d0de8251 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ImageDiffuser.cs @@ -61,7 +61,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); //TODO: Model Config, Channels var outputDimension = options.GetScaledDimension(); diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs index 57bdd363..04795569 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs @@ -158,7 +158,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// protected override async Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) { - var imageTensor = await prompt.InputImage.ToDenseTensorAsync(options.Height, options.Width); + var imageTensor = await prompt.InputImage.GetImageTensorAsync(options.Height, options.Width); //TODO: Model Config, Channels var outputDimensions = options.GetScaledDimension(); @@ -191,7 +191,7 @@ protected override async Task> PrepareLatentsAsync(PromptOpti /// private DenseTensor PrepareMask(PromptOptions promptOptions, SchedulerOptions schedulerOptions) { - using (var mask = promptOptions.InputImageMask.ToImage()) + using (var mask = promptOptions.InputImageMask.GetImage()) { // Prepare the mask int width = schedulerOptions.GetScaledWidth(); diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs index a0b7f18b..26c0d6b2 100644 --- a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs +++ b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs @@ -162,13 +162,13 @@ protected async Task> DiffuseVideoAsync(IDiffuser diffuser, P { // ControlNetImage uses frame as input image if (promptOptions.DiffuserType == DiffuserType.ControlNetImage) - promptOptions.InputImage = new InputImage(videoFrame.Frame); + promptOptions.InputImage = new OnnxImage(videoFrame.Frame); promptOptions.InputContolImage = videoFrame.ExtraFrame; } else { - promptOptions.InputImage = new InputImage(videoFrame.Frame); + promptOptions.InputImage = new OnnxImage(videoFrame.Frame); } var frameResultTensor = await diffuser.DiffuseAsync(promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken); diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs index 2e8a6fca..5d3b809f 100644 --- 
a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs @@ -2,6 +2,7 @@ using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.Core; using OnnxStack.Core.Config; +using OnnxStack.Core.Image; using OnnxStack.Core.Model; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; @@ -230,7 +231,7 @@ public override async IAsyncEnumerable RunBatchAsync(BatchOptions b ? await DiffuseVideoAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken) : await DiffuseImageAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken); - yield return new BatchResult(batchSchedulerOption, tensorResult); + yield return new BatchResult(batchSchedulerOption, new OnnxImage(tensorResult)); batchIndex++; } diff --git a/OnnxStack.UI/Services/IStableDiffusionService.cs b/OnnxStack.UI/Services/IStableDiffusionService.cs index 085dcf2f..73c934c4 100644 --- a/OnnxStack.UI/Services/IStableDiffusionService.cs +++ b/OnnxStack.UI/Services/IStableDiffusionService.cs @@ -1,12 +1,9 @@ using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core.Config; +using OnnxStack.Core.Image; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; -using SixLabors.ImageSharp; -using SixLabors.ImageSharp.PixelFormats; using System; using System.Collections.Generic; -using System.IO; using System.Threading; using System.Threading.Tasks; @@ -71,37 +68,7 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - Task> GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); - - /// - /// Generates the StableDiffusion image using the prompt and options provided. - /// - /// The prompt. - /// The Scheduler options. - /// The callback used to provide progess of the current InferenceSteps. - /// The cancellation token. - /// The diffusion result as - Task> GenerateAsImageAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); - - /// - /// Generates the StableDiffusion image using the prompt and options provided. - /// - /// The prompt. - /// The Scheduler options. - /// The callback used to provide progess of the current InferenceSteps. - /// The cancellation token. - /// The diffusion result as - Task GenerateAsBytesAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); - - /// - /// Generates the StableDiffusion image using the prompt and options provided. - /// - /// The prompt. - /// The Scheduler options. - /// The callback used to provide progess of the current InferenceSteps. - /// The cancellation token. - /// The diffusion result as - Task GenerateAsStreamAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates a batch of StableDiffusion image using the prompt and options provided. 
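With the byte[], Stream and Image<Rgba32> overloads removed, the interface keeps a single GenerateAsync. A rough sketch of the intended call site, assuming the elided return type is Task<OnnxImage>, which is what the StableDiffusionService implementation and the view code later in this patch consume; the model set, options, callback and token variables are placeholders.

```csharp
// Sketch only: callers receive an OnnxImage and choose how to materialise it.
OnnxImage result = await _stableDiffusionService.GenerateAsync(
    new ModelOptions(modelSet), promptOptions, schedulerOptions,
    progressCallback, cancellationToken);

await result.SaveAsync("Output.png");      // write to disk
var pngBytes = result.GetImageBytes();     // or hand the bytes to a BitmapImage, etc.
```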
@@ -114,41 +81,5 @@ public interface IStableDiffusionService /// The cancellation token. /// IAsyncEnumerable GenerateBatchAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); - - /// - /// Generates a batch of StableDiffusion image using the prompt and options provided. - /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The batch options. - /// The progress callback. - /// The cancellation token. - /// - IAsyncEnumerable> GenerateBatchAsImageAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); - - /// - /// Generates a batch of StableDiffusion image using the prompt and options provided. - /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The batch options. - /// The progress callback. - /// The cancellation token. - /// - IAsyncEnumerable GenerateBatchAsBytesAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); - - /// - /// Generates a batch of StableDiffusion image using the prompt and options provided. - /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The batch options. - /// The progress callback. - /// The cancellation token. - /// - IAsyncEnumerable GenerateBatchAsStreamAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/OnnxStack.UI/Services/StableDiffusionService.cs b/OnnxStack.UI/Services/StableDiffusionService.cs index f205a4d8..095a4ea6 100644 --- a/OnnxStack.UI/Services/StableDiffusionService.cs +++ b/OnnxStack.UI/Services/StableDiffusionService.cs @@ -129,62 +129,15 @@ public bool IsControlNetModelLoaded(ControlNetModelSet modelOptions) /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - public async Task> GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) + public async Task GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { - return await DiffuseAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false); - } - - - /// - /// Generates the StableDiffusion image using the prompt and options provided. - /// - /// The prompt. - /// The Scheduler options. - /// The callback used to provide progess of the current InferenceSteps. - /// The cancellation token. 
- /// The diffusion result as - public async Task> GenerateAsImageAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) - { - return await GenerateAsync(model, prompt, options, progressCallback, cancellationToken) - .ContinueWith(t => t.Result.ToImage(), cancellationToken) + return await DiffuseAsync(model, prompt, options, progressCallback, cancellationToken) + .ContinueWith(t => new OnnxImage(t.Result), cancellationToken) .ConfigureAwait(false); } - /// - /// Generates the StableDiffusion image using the prompt and options provided. - /// - /// The prompt. - /// The Scheduler options. - /// The callback used to provide progess of the current InferenceSteps. - /// The cancellation token. - /// The diffusion result as - public async Task GenerateAsBytesAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) - { - var generateResult = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false); - if (!prompt.HasInputVideo) - return generateResult.ToImageBytes(); - return await GenerateVideoResultAsBytesAsync(generateResult, prompt.VideoOutputFPS, progressCallback, cancellationToken).ConfigureAwait(false); - } - - - /// - /// Generates the StableDiffusion image using the prompt and options provided. - /// - /// The prompt. - /// The Scheduler options. - /// The callback used to provide progess of the current InferenceSteps. - /// The cancellation token. - /// The diffusion result as - public async Task GenerateAsStreamAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) - { - var generateResult = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false); - if (!prompt.HasInputVideo) - return generateResult.ToImageStream(); - - return await GenerateVideoResultAsStreamAsync(generateResult, prompt.VideoOutputFPS, progressCallback, cancellationToken).ConfigureAwait(false); - } /// @@ -203,65 +156,11 @@ public IAsyncEnumerable GenerateBatchAsync(ModelOptions modelOption } - /// - /// Generates a batch of StableDiffusion image using the prompt and options provided. - /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The batch options. - /// The progress callback. - /// The cancellation token. - /// - public async IAsyncEnumerable> GenerateBatchAsImageAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) - yield return result.ImageResult.ToImage(); - } - /// - /// Generates a batch of StableDiffusion image using the prompt and options provided. - /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The batch options. - /// The progress callback. - /// The cancellation token. 
- /// - public async IAsyncEnumerable GenerateBatchAsBytesAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) - { - if (!promptOptions.HasInputVideo) - yield return result.ImageResult.ToImageBytes(); - - yield return await GenerateVideoResultAsBytesAsync(result.ImageResult, promptOptions.VideoOutputFPS, progressCallback, cancellationToken).ConfigureAwait(false); - } - } - /// - /// Generates a batch of StableDiffusion image using the prompt and options provided. - /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The batch options. - /// The progress callback. - /// The cancellation token. - /// - public async IAsyncEnumerable GenerateBatchAsStreamAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) - { - if (!promptOptions.HasInputVideo) - yield return result.ImageResult.ToImageStream(); - yield return await GenerateVideoResultAsStreamAsync(result.ImageResult, promptOptions.VideoOutputFPS, progressCallback, cancellationToken).ConfigureAwait(false); - } - } /// diff --git a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs index f8793562..1821680f 100644 --- a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs +++ b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs @@ -220,7 +220,7 @@ private async Task Generate() try { var timestamp = Stopwatch.GetTimestamp(); - var result = await _stableDiffusionService.GenerateAsBytesAsync(new ModelOptions(_selectedModel.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); + var result = await _stableDiffusionService.GenerateAsync(new ModelOptions(_selectedModel.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); var resultImage = await GenerateResultAsync(result, promptOptions, schedulerOptions, timestamp); if (resultImage != null) { @@ -330,14 +330,8 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, Im DiffuserType = SelectedModel.ModelSet.Diffusers.Contains(DiffuserType.ImageInpaint) ? DiffuserType.ImageInpaint : DiffuserType.ImageInpaintLegacy, - InputImage = new InputImage - { - ImageBytes = imageInput.Image.GetImageBytes() - }, - InputImageMask = new InputImage - { - ImageBytes = imageInputMask.Image.GetImageBytes() - } + InputImage = new OnnxImage(imageInput.Image.GetImageBytes()), + InputImageMask = new OnnxImage(imageInputMask.Image.GetImageBytes()) }; } @@ -350,9 +344,9 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, Im /// The scheduler options. /// The timestamp. 
/// - private Task GenerateResultAsync(byte[] imageBytes, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) + private Task GenerateResultAsync(OnnxImage onnxImage, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) { - var image = Utils.CreateBitmap(imageBytes); + var image = Utils.CreateBitmap(onnxImage.GetImageBytes()); var imageResult = new ImageResult { diff --git a/OnnxStack.UI/Views/ImageToImageView.xaml.cs b/OnnxStack.UI/Views/ImageToImageView.xaml.cs index 332fe7db..9d394088 100644 --- a/OnnxStack.UI/Views/ImageToImageView.xaml.cs +++ b/OnnxStack.UI/Views/ImageToImageView.xaml.cs @@ -208,7 +208,7 @@ private async Task Generate() try { var timestamp = Stopwatch.GetTimestamp(); - var result = await _stableDiffusionService.GenerateAsBytesAsync(new ModelOptions(_selectedModel.ModelSet, _selectedControlNetModel?.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); + var result = await _stableDiffusionService.GenerateAsync(new ModelOptions(_selectedModel.ModelSet, _selectedControlNetModel?.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); var resultImage = await GenerateResultAsync(result, promptOptions, schedulerOptions, timestamp); if (resultImage != null) { @@ -310,9 +310,9 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, Im ? DiffuserType.ControlNet : DiffuserType.ControlNetImage; - var inputImage = default(InputImage); + var inputImage = default(OnnxImage); if (controlNetDiffuserType == DiffuserType.ControlNetImage) - inputImage = new InputImage(imageBytes); + inputImage = new OnnxImage(imageBytes); return new PromptOptions { @@ -320,7 +320,7 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, Im NegativePrompt = promptOptionsModel.NegativePrompt, DiffuserType = controlNetDiffuserType, InputImage = inputImage, - InputContolImage = new InputImage(imageBytes) + InputContolImage = new OnnxImage(imageBytes) }; } @@ -329,7 +329,7 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, Im Prompt = promptOptionsModel.Prompt, NegativePrompt = promptOptionsModel.NegativePrompt, DiffuserType = DiffuserType.ImageToImage, - InputImage = new InputImage(imageBytes) + InputImage = new OnnxImage(imageBytes) }; } @@ -342,9 +342,9 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, Im /// The scheduler options. /// The timestamp. 
/// - private Task GenerateResultAsync(byte[] imageBytes, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) + private Task GenerateResultAsync(OnnxImage onnxImage, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) { - var image = Utils.CreateBitmap(imageBytes); + var image = Utils.CreateBitmap(onnxImage.GetImageBytes()); var imageResult = new ImageResult { diff --git a/OnnxStack.UI/Views/TextToImageView.xaml.cs b/OnnxStack.UI/Views/TextToImageView.xaml.cs index 18144011..05df66a1 100644 --- a/OnnxStack.UI/Views/TextToImageView.xaml.cs +++ b/OnnxStack.UI/Views/TextToImageView.xaml.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.Logging; +using OnnxStack.Core.Image; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; @@ -180,7 +181,7 @@ private async Task Generate() try { var timestamp = Stopwatch.GetTimestamp(); - var result = await _stableDiffusionService.GenerateAsBytesAsync(new ModelOptions(_selectedModel.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); + var result = await _stableDiffusionService.GenerateAsync(new ModelOptions(_selectedModel.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); var resultImage = await GenerateResultAsync(result, promptOptions, schedulerOptions, timestamp); if (resultImage != null) { @@ -292,9 +293,9 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel) }; } - private Task GenerateResultAsync(byte[] imageBytes, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) + private Task GenerateResultAsync(OnnxImage onnxImage, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) { - var image = Utils.CreateBitmap(imageBytes); + var image = Utils.CreateBitmap(onnxImage.GetImageBytes()); var imageResult = new ImageResult { diff --git a/OnnxStack.UI/Views/UpscaleView.xaml.cs b/OnnxStack.UI/Views/UpscaleView.xaml.cs index c088cf8a..78250634 100644 --- a/OnnxStack.UI/Views/UpscaleView.xaml.cs +++ b/OnnxStack.UI/Views/UpscaleView.xaml.cs @@ -199,7 +199,7 @@ private async Task Generate() try { var timestamp = Stopwatch.GetTimestamp(); - var resultBytes = await _upscaleService.GenerateAsByteAsync(SelectedModel.ModelSet, new InputImage(InputImage.GetImageBytes()), _cancelationTokenSource.Token); + var resultBytes = await _upscaleService.GenerateAsByteAsync(SelectedModel.ModelSet, new OnnxImage(InputImage.GetImageBytes()), _cancelationTokenSource.Token); if (resultBytes != null) { var elapsed = Stopwatch.GetElapsedTime(timestamp).TotalSeconds; diff --git a/OnnxStack.UI/Views/VideoToVideoView.xaml.cs b/OnnxStack.UI/Views/VideoToVideoView.xaml.cs index 81d3e69a..02af02c0 100644 --- a/OnnxStack.UI/Views/VideoToVideoView.xaml.cs +++ b/OnnxStack.UI/Views/VideoToVideoView.xaml.cs @@ -225,7 +225,7 @@ private async Task Generate() var promptOptions = GetPromptOptions(PromptOptions, _videoFrames); var timestamp = Stopwatch.GetTimestamp(); - var result = await _stableDiffusionService.GenerateAsBytesAsync(new ModelOptions(_selectedModel.ModelSet, _selectedControlNetModel?.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); + var result = await _stableDiffusionService.GenerateAsync(new ModelOptions(_selectedModel.ModelSet, _selectedControlNetModel?.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); var 
resultVideo = await GenerateResultAsync(result, promptOptions, schedulerOptions, timestamp); if (resultVideo != null) { @@ -354,8 +354,10 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, Vi /// The scheduler options. /// The timestamp. /// - private async Task GenerateResultAsync(byte[] videoBytes, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) + private async Task GenerateResultAsync(OnnxImage onnxImage, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) { + var videoBytes = onnxImage.GetImageBytes(); + var tempVideoFile = Path.Combine(".temp", $"VideoToVideo.mp4"); await File.WriteAllBytesAsync(tempVideoFile, videoBytes); var videoInfo = await _videoService.GetVideoInfoAsync(videoBytes); @@ -387,7 +389,7 @@ private Action ProgressCallback() if (progress.BatchTensor is not null) { - PreviewResult = Utils.CreateBitmap(progress.BatchTensor.ToImageBytes()); + PreviewResult = Utils.CreateBitmap(new OnnxImage( progress.BatchTensor).GetImageBytes()); PreviewSource = UpdatePreviewFrame(progress.BatchValue - 1); ProgressText = $"Video Frame {progress.BatchValue} of {_videoFrames.Frames.Count} complete"; } @@ -411,7 +413,7 @@ public BitmapImage UpdatePreviewFrame(int index) using (var memoryStream = new MemoryStream()) using (var frameImage = SixLabors.ImageSharp.Image.Load(frame.Frame)) { - frameImage.Resize(_schedulerOptions.Height, _schedulerOptions.Width); + //frameImage.Resize(_schedulerOptions.Height, _schedulerOptions.Width); frameImage.SaveAsPng(memoryStream); var image = new BitmapImage(); image.BeginInit(); From 21a1e211433c380bd688889c974fcf7afee039a4 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Mon, 12 Feb 2024 20:23:25 +1300 Subject: [PATCH 2/5] Update ImageUpscale to new API format --- OnnxStack.Console/Examples/UpscaleExample.cs | 6 +- OnnxStack.Console/Program.cs | 1 - .../Config/ImageUpscalerConfig.cs | 16 - .../Extensions/ImageExtensions.cs | 6 +- OnnxStack.ImageUpscaler/Models/ImageTile.cs | 2 +- .../Models/UpscaleInput.cs | 2 +- OnnxStack.ImageUpscaler/README.md | 30 +- OnnxStack.ImageUpscaler/Registration.cs | 63 ---- .../Services/IUpscaleService.cs | 100 ------ .../Services/UpscaleService.cs | 314 ------------------ OnnxStack.UI/App.xaml.cs | 4 +- OnnxStack.UI/MainWindow.xaml.cs | 2 - OnnxStack.UI/Services/IUpscaleService.cs | 56 ++++ OnnxStack.UI/Services/UpscaleService.cs | 154 +++++++++ .../UserControls/UpscalePickerControl.xaml.cs | 2 +- OnnxStack.UI/Views/UpscaleView.xaml.cs | 6 +- OnnxStack.UI/appsettings.json | 6 +- 17 files changed, 255 insertions(+), 515 deletions(-) delete mode 100644 OnnxStack.ImageUpscaler/Config/ImageUpscalerConfig.cs delete mode 100644 OnnxStack.ImageUpscaler/Registration.cs delete mode 100644 OnnxStack.ImageUpscaler/Services/IUpscaleService.cs delete mode 100644 OnnxStack.ImageUpscaler/Services/UpscaleService.cs create mode 100644 OnnxStack.UI/Services/IUpscaleService.cs create mode 100644 OnnxStack.UI/Services/UpscaleService.cs diff --git a/OnnxStack.Console/Examples/UpscaleExample.cs b/OnnxStack.Console/Examples/UpscaleExample.cs index 8c232ae9..15360de1 100644 --- a/OnnxStack.Console/Examples/UpscaleExample.cs +++ b/OnnxStack.Console/Examples/UpscaleExample.cs @@ -1,18 +1,14 @@ using OnnxStack.Core.Image; using OnnxStack.FeatureExtractor.Pipelines; -using OnnxStack.ImageUpscaler.Config; -using SixLabors.ImageSharp; namespace OnnxStack.Console.Runner { public sealed class UpscaleExample : IExampleRunner { private readonly string _outputDirectory; 
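Patch 2 moves upscaling onto the pipeline-based API, so the console example no longer needs an ImageUpscalerConfig. As a rough sketch of the relocated UI-side IUpscaleService added further down in this patch (the video service, model set, image bytes and token are placeholders):

```csharp
// Sketch only: load the model once, then upscale OnnxImage inputs on demand.
IUpscaleService upscaleService = new UpscaleService(videoService);
await upscaleService.LoadModelAsync(upscaleModelSet);

OnnxImage upscaled = await upscaleService.GenerateAsync(
    upscaleModelSet, new OnnxImage(inputImageBytes), cancellationToken);

await upscaled.SaveAsync("Upscaled.png");
await upscaleService.UnloadModelAsync(upscaleModelSet);
```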
- private readonly ImageUpscalerConfig _configuration; - public UpscaleExample(ImageUpscalerConfig configuration) + public UpscaleExample() { - _configuration = configuration; _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(UpscaleExample)); Directory.CreateDirectory(_outputDirectory); } diff --git a/OnnxStack.Console/Program.cs b/OnnxStack.Console/Program.cs index 00b06408..d9337a19 100644 --- a/OnnxStack.Console/Program.cs +++ b/OnnxStack.Console/Program.cs @@ -19,7 +19,6 @@ static async Task Main(string[] _) // Add OnnxStack builder.Services.AddOnnxStack(); builder.Services.AddOnnxStackConfig(); - builder.Services.AddOnnxStackImageUpscaler(); // Add AppService builder.Services.AddHostedService(); diff --git a/OnnxStack.ImageUpscaler/Config/ImageUpscalerConfig.cs b/OnnxStack.ImageUpscaler/Config/ImageUpscalerConfig.cs deleted file mode 100644 index 1bd90f30..00000000 --- a/OnnxStack.ImageUpscaler/Config/ImageUpscalerConfig.cs +++ /dev/null @@ -1,16 +0,0 @@ -using OnnxStack.Common.Config; -using OnnxStack.Core; -using OnnxStack.ImageUpscaler.Common; -using System.Collections.Generic; - -namespace OnnxStack.ImageUpscaler.Config -{ - public class ImageUpscalerConfig : IConfigSection - { - public List ModelSets { get; set; } = new List(); - - public void Initialize() - { - } - } -} diff --git a/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs b/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs index 897a905f..60aee828 100644 --- a/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs +++ b/OnnxStack.ImageUpscaler/Extensions/ImageExtensions.cs @@ -18,7 +18,7 @@ internal static class ImageExtensions /// Maximum size of the tile. /// The scale factor. /// - public static List GenerateTiles(this OnnxImage imageSource, int sampleSize, int scaleFactor) + internal static List GenerateTiles(this OnnxImage imageSource, int sampleSize, int scaleFactor) { var tiles = new List(); var tileSizeX = Math.Min(sampleSize, imageSource.Width); @@ -48,7 +48,7 @@ public static List GenerateTiles(this OnnxImage imageSource, int samp /// The source image. /// The source area. 
/// - public static OnnxImage ExtractTile(this OnnxImage sourceImage, Rectangle sourceArea) + internal static OnnxImage ExtractTile(this OnnxImage sourceImage, Rectangle sourceArea) { var height = sourceArea.Height; var targetImage = new Image(sourceArea.Width, sourceArea.Height); @@ -65,7 +65,7 @@ public static OnnxImage ExtractTile(this OnnxImage sourceImage, Rectangle source } - public static void ApplyImageTile(this DenseTensor imageTensor, DenseTensor tileTensor, Rectangle location) + internal static void ApplyImageTile(this DenseTensor imageTensor, DenseTensor tileTensor, Rectangle location) { var offsetY = location.Y; var offsetX = location.X; diff --git a/OnnxStack.ImageUpscaler/Models/ImageTile.cs b/OnnxStack.ImageUpscaler/Models/ImageTile.cs index 12502aab..57e4ac77 100644 --- a/OnnxStack.ImageUpscaler/Models/ImageTile.cs +++ b/OnnxStack.ImageUpscaler/Models/ImageTile.cs @@ -3,7 +3,7 @@ namespace OnnxStack.ImageUpscaler.Models { - public record ImageTile + internal record ImageTile { public OnnxImage Image { get; set; } public Rectangle Destination { get; set; } diff --git a/OnnxStack.ImageUpscaler/Models/UpscaleInput.cs b/OnnxStack.ImageUpscaler/Models/UpscaleInput.cs index eb6d4526..05d9f3a5 100644 --- a/OnnxStack.ImageUpscaler/Models/UpscaleInput.cs +++ b/OnnxStack.ImageUpscaler/Models/UpscaleInput.cs @@ -2,7 +2,7 @@ namespace OnnxStack.ImageUpscaler.Models { - public record UpscaleInput(List ImageTiles, int OutputWidth, int OutputHeight); + internal record UpscaleInput(List ImageTiles, int OutputWidth, int OutputHeight); } diff --git a/OnnxStack.ImageUpscaler/README.md b/OnnxStack.ImageUpscaler/README.md index e2a9542a..096ac70d 100644 --- a/OnnxStack.ImageUpscaler/README.md +++ b/OnnxStack.ImageUpscaler/README.md @@ -1 +1,29 @@ -# OnnxStack.ImageUpsacler \ No newline at end of file +# OnnxStack.ImageUpscaler + +## Upscale Models +https://huggingface.co/wuminghao/swinir +https://huggingface.co/rocca/swin-ir-onnx +https://huggingface.co/Xenova/swin2SR-classical-sr-x2-64 +https://huggingface.co/Xenova/swin2SR-classical-sr-x4-64 + + +# Basic Example +```csharp +// Load Input Image +var inputImage = await OnnxImage.FromFileAsync("Input.png"); + +// Create Pipeline +var pipeline = ImageUpscalePipeline.CreatePipeline("003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx", scaleFactor: 4); + +// Run pipeline +var result = await pipeline.RunAsync(inputImage); + +// Create Image from Tensor result +var image = new OnnxImage(result, ImageNormalizeType.ZeroToOne); + +// Save Image File +await image.SaveAsync("Upscaled.png"); + +// Unload +await pipeline.UnloadAsync(); +``` \ No newline at end of file diff --git a/OnnxStack.ImageUpscaler/Registration.cs b/OnnxStack.ImageUpscaler/Registration.cs deleted file mode 100644 index 0b1a86ed..00000000 --- a/OnnxStack.ImageUpscaler/Registration.cs +++ /dev/null @@ -1,63 +0,0 @@ -using Microsoft.Extensions.DependencyInjection; -using OnnxStack.Core; -using OnnxStack.Core.Config; -using OnnxStack.Core.Services; -using OnnxStack.ImageUpscaler.Config; -using OnnxStack.ImageUpscaler.Services; - -namespace OnnxStack.ImageUpscaler -{ - public static class Registration - { - /// - /// Register OnnxStack ImageUpscaler services - /// - /// The service collection. 
- public static void AddOnnxStackImageUpscaler(this IServiceCollection serviceCollection) - { - serviceCollection.AddOnnxStack(); - serviceCollection.RegisterServices(); - serviceCollection.AddSingleton(TryLoadAppSettings()); - } - - - /// - /// Register OnnxStack ImageUpscaler services, AddOnnxStack() must be called before - /// - /// The service collection. - /// The configuration. - public static void AddOnnxStackImageUpscaler(this IServiceCollection serviceCollection, ImageUpscalerConfig configuration) - { - serviceCollection.RegisterServices(); - serviceCollection.AddSingleton(configuration); - } - - - /// - /// Registers the services. - /// - /// The service collection. - private static void RegisterServices(this IServiceCollection serviceCollection) - { - serviceCollection.AddSingleton(); - serviceCollection.AddSingleton(); - } - - - /// - /// Try load ImageUpscalerConfig from application settings. - /// - /// - private static ImageUpscalerConfig TryLoadAppSettings() - { - try - { - return ConfigManager.LoadConfiguration(); - } - catch - { - return new ImageUpscalerConfig(); - } - } - } -} \ No newline at end of file diff --git a/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs b/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs deleted file mode 100644 index ba71b7e9..00000000 --- a/OnnxStack.ImageUpscaler/Services/IUpscaleService.cs +++ /dev/null @@ -1,100 +0,0 @@ -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core.Image; -using OnnxStack.Core.Video; -using OnnxStack.ImageUpscaler.Common; -using SixLabors.ImageSharp; -using SixLabors.ImageSharp.PixelFormats; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace OnnxStack.ImageUpscaler.Services -{ - public interface IUpscaleService - { - - /// - /// Loads the model. - /// - /// The model. - /// - Task LoadModelAsync(UpscaleModelSet model); - - /// - /// Unloads the model. - /// - /// The model. - /// - Task UnloadModelAsync(UpscaleModelSet model); - - /// - /// Determines whether [is model loaded] [the specified model options]. - /// - /// The model options. - /// - /// true if [is model loaded] [the specified model options]; otherwise, false. - /// - bool IsModelLoaded(UpscaleModelSet modelOptions); - - /// - /// Generates the upscaled image. - /// - /// The model options. - /// The input image. - /// - Task> GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); - - /// - /// Generates the upscaled image. - /// - /// The model options. - /// The input image. - /// - Task GenerateAsImageAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); - - - /// - /// Generates the upscaled image. - /// - /// The model options. - /// The input image. - /// - Task GenerateAsByteAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); - - - /// - /// Generates the upscaled image. - /// - /// The model options. - /// The input image. - /// - Task GenerateAsStreamAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); - - - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. - /// - Task> GenerateAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken = default); - - - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. 
- /// - Task GenerateAsByteAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken = default); - - - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. - /// - Task GenerateAsStreamAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken = default); - } -} diff --git a/OnnxStack.ImageUpscaler/Services/UpscaleService.cs b/OnnxStack.ImageUpscaler/Services/UpscaleService.cs deleted file mode 100644 index 8b0c3e37..00000000 --- a/OnnxStack.ImageUpscaler/Services/UpscaleService.cs +++ /dev/null @@ -1,314 +0,0 @@ -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core; -using OnnxStack.Core.Config; -using OnnxStack.Core.Image; -using OnnxStack.Core.Model; -using OnnxStack.Core.Services; -using OnnxStack.Core.Video; -using OnnxStack.ImageUpscaler.Common; -using OnnxStack.ImageUpscaler.Extensions; -using OnnxStack.ImageUpscaler.Models; -using SixLabors.ImageSharp; -using SixLabors.ImageSharp.PixelFormats; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; - -namespace OnnxStack.ImageUpscaler.Services -{ - public class UpscaleService : IUpscaleService - { - private readonly IVideoService _videoService; - private readonly Dictionary _modelSessions; - - /// - /// Initializes a new instance of the class. - /// - /// The configuration. - /// The model service. - /// The image service. - public UpscaleService(IVideoService videoService) - { - _videoService = videoService; - _modelSessions = new Dictionary(); - } - - - /// - /// Loads the model. - /// - /// The model. - /// - public Task LoadModelAsync(UpscaleModelSet model) - { - if (_modelSessions.ContainsKey(model)) - return Task.FromResult(true); - - return Task.FromResult(_modelSessions.TryAdd(model, CreateModelSession(model))); - } - - - /// - /// Unloads the model. - /// - /// The model. - /// - public Task UnloadModelAsync(UpscaleModelSet model) - { - if (_modelSessions.Remove(model, out var session)) - { - session?.Dispose(); - } - return Task.FromResult(true); - } - - - /// - /// Determines whether [is model loaded] [the specified model options]. - /// - /// The model options. - /// - /// true if [is model loaded] [the specified model options]; otherwise, false. - /// - /// - public bool IsModelLoaded(UpscaleModelSet modelOptions) - { - return _modelSessions.ContainsKey(modelOptions); - } - - - /// - /// Generates the upscaled image. - /// - /// The model options. - /// The input image. - /// - public async Task> GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default) - { - return await GenerateInternalAsync(modelOptions, inputImage, cancellationToken); - } - - - /// - /// Generates the upscaled image. - /// - /// The model options. - /// The input image. - /// - public async Task GenerateAsImageAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default) - { - var imageTensor = await GenerateInternalAsync(modelOptions, inputImage, cancellationToken); - return new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne); - } - - - /// - /// Generates the upscaled image. - /// - /// The model options. - /// The input image. 
- /// - public async Task GenerateAsByteAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default) - { - var imageTensor = await GenerateInternalAsync(modelOptions, inputImage, cancellationToken); - using (var memoryStream = new MemoryStream()) - using (var image = new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne)) - { - await image.SaveAsync(memoryStream); - return memoryStream.ToArray(); - } - } - - - /// - /// Generates the upscaled image. - /// - /// The model options. - /// The input image. - /// - public async Task GenerateAsStreamAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default) - { - var imageTensor = await GenerateInternalAsync(modelOptions, inputImage, cancellationToken); - using (var image = new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne)) - { - var memoryStream = new MemoryStream(); - await image.SaveAsync(memoryStream); - return memoryStream; - } - } - - - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. - /// - public async Task> GenerateAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken = default) - { - var videoInfo = await _videoService.GetVideoInfoAsync(videoInput); - var tensorFrames = await GenerateInternalAsync(modelOptions, videoInput, videoInfo, cancellationToken); - - DenseTensor videoResult = default; - foreach (var tensorFrame in tensorFrames) - { - cancellationToken.ThrowIfCancellationRequested(); - videoResult = videoResult.Concatenate(tensorFrame); - } - return videoResult; - } - - - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. - /// - public async Task GenerateAsByteAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken) - { - var outputTasks = new List>(); - var videoInfo = await _videoService.GetVideoInfoAsync(videoInput); - var tensorFrames = await GenerateInternalAsync(modelOptions, videoInput, videoInfo, cancellationToken); - foreach (DenseTensor tensorFrame in tensorFrames) - { - outputTasks.Add(Task.Run(async () => - { - cancellationToken.ThrowIfCancellationRequested(); - - using (var imageStream = new MemoryStream()) - using (var imageFrame = new OnnxImage(tensorFrame, ImageNormalizeType.ZeroToOne)) - { - await imageFrame.SaveAsync(imageStream); - return imageStream.ToArray(); - } - })); - } - - var output = await Task.WhenAll(outputTasks); - var videoResult = await _videoService.CreateVideoAsync(output, videoInfo.FPS); - return videoResult.Data; - } - - - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. - /// - public async Task GenerateAsStreamAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken) - { - return new MemoryStream(await GenerateAsByteAsync(modelOptions, videoInput, cancellationToken)); - } - - - /// - /// Generates an upscaled image of the source provided. - /// - /// The model options. - /// The input image. 
- private async Task> GenerateInternalAsync(UpscaleModelSet modelSet, OnnxImage inputImage, CancellationToken cancellationToken) - { - if (!_modelSessions.TryGetValue(modelSet, out var modelSession)) - throw new System.Exception("Model not loaded"); - - var upscaleInput = CreateInputParams(inputImage, modelSession.SampleSize, modelSession.ScaleFactor); - var metadata = await modelSession.GetMetadataAsync(); - - var outputTensor = new DenseTensor(new[] { 1, modelSession.Channels, upscaleInput.OutputHeight, upscaleInput.OutputWidth }); - foreach (var imageTile in upscaleInput.ImageTiles) - { - cancellationToken.ThrowIfCancellationRequested(); - - var outputDimension = new[] { 1, modelSession.Channels, imageTile.Destination.Height, imageTile.Destination.Width }; - var inputTensor = imageTile.Image.GetImageTensor(ImageNormalizeType.ZeroToOne, modelSession.Channels); - using (var inferenceParameters = new OnnxInferenceParameters(metadata)) - { - inferenceParameters.AddInputTensor(inputTensor); - inferenceParameters.AddOutputBuffer(outputDimension); - - var results = await modelSession.RunInferenceAsync(inferenceParameters); - using (var result = results.First()) - { - outputTensor.ApplyImageTile(result.ToDenseTensor(), imageTile.Destination); - } - } - } - - return outputTensor; - - } - - - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. - /// - public async Task>> GenerateInternalAsync(UpscaleModelSet modelSet, VideoInput videoInput, VideoInfo videoInfo, CancellationToken cancellationToken) - { - if (!_modelSessions.TryGetValue(modelSet, out var modelSession)) - throw new System.Exception("Model not loaded"); - - var videoFrames = await _videoService.CreateFramesAsync(videoInput, videoInfo.FPS); - var metadata = await modelSession.GetMetadataAsync(); - - // Create Inputs - var outputTensors = new List>(); - foreach (var frame in videoFrames.Frames) - { - using (var imageFrame = new OnnxImage(frame.Frame)) - { - var input = CreateInputParams(imageFrame, modelSession.SampleSize, modelSession.ScaleFactor); - var outputDimension = new[] { 1, modelSession.Channels, 0, 0 }; - var outputTensor = new DenseTensor(new[] { 1, modelSession.Channels, input.OutputHeight, input.OutputWidth }); - foreach (var imageTile in input.ImageTiles) - { - var inputTensor = imageTile.Image.GetImageTensor(ImageNormalizeType.ZeroToOne, modelSession.Channels); - outputDimension[2] = imageTile.Destination.Height; - outputDimension[3] = imageTile.Destination.Width; - using (var inferenceParameters = new OnnxInferenceParameters(metadata)) - { - inferenceParameters.AddInputTensor(inputTensor); - inferenceParameters.AddOutputBuffer(outputDimension); - - var results = await modelSession.RunInferenceAsync(inferenceParameters); - using (var result = results.First()) - { - outputTensor.ApplyImageTile(result.ToDenseTensor(), imageTile.Destination); - } - } - } - outputTensors.Add(outputTensor); - } - } - return outputTensors; - } - - - /// - /// Creates the input parameters. - /// - /// The image source. - /// Maximum size of the tile. - /// The scale factor. 
- /// - private static UpscaleInput CreateInputParams(OnnxImage imageSource, int maxTileSize, int scaleFactor) - { - var tiles = imageSource.GenerateTiles(maxTileSize, scaleFactor); - var width = imageSource.Width * scaleFactor; - var height = imageSource.Height * scaleFactor; - return new UpscaleInput(tiles, width, height); - } - - - private UpscaleModel CreateModelSession(UpscaleModelSet modelSet) - { - return new UpscaleModel(modelSet.UpscaleModelConfig.ApplyDefaults(modelSet)); - } - } -} \ No newline at end of file diff --git a/OnnxStack.UI/App.xaml.cs b/OnnxStack.UI/App.xaml.cs index b7b090ad..1de842c9 100644 --- a/OnnxStack.UI/App.xaml.cs +++ b/OnnxStack.UI/App.xaml.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using OnnxStack.Core; +using OnnxStack.Core.Services; using OnnxStack.ImageUpscaler; using OnnxStack.StableDiffusion.Config; using OnnxStack.UI.Dialogs; @@ -28,7 +29,6 @@ public App() // Add OnnxStackStableDiffusion builder.Services.AddOnnxStack(); - builder.Services.AddOnnxStackImageUpscaler(); builder.Services.AddOnnxStackConfig(); // Add Windows @@ -49,6 +49,8 @@ public App() builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); + builder.Services.AddSingleton(); + builder.Services.AddSingleton(); // Build App diff --git a/OnnxStack.UI/MainWindow.xaml.cs b/OnnxStack.UI/MainWindow.xaml.cs index 3e14c02f..ec4bd3ea 100644 --- a/OnnxStack.UI/MainWindow.xaml.cs +++ b/OnnxStack.UI/MainWindow.xaml.cs @@ -1,7 +1,5 @@ using Microsoft.Extensions.Logging; using Microsoft.Win32; -using OnnxStack.ImageUpscaler.Config; -using OnnxStack.StableDiffusion.Config; using OnnxStack.UI.Commands; using OnnxStack.UI.Models; using OnnxStack.UI.Views; diff --git a/OnnxStack.UI/Services/IUpscaleService.cs b/OnnxStack.UI/Services/IUpscaleService.cs new file mode 100644 index 00000000..ed1712c6 --- /dev/null +++ b/OnnxStack.UI/Services/IUpscaleService.cs @@ -0,0 +1,56 @@ +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core.Image; +using OnnxStack.Core.Video; +using OnnxStack.ImageUpscaler.Common; +using SixLabors.ImageSharp; +using SixLabors.ImageSharp.PixelFormats; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.UI.Services +{ + public interface IUpscaleService + { + + /// + /// Loads the model. + /// + /// The model. + /// + Task LoadModelAsync(UpscaleModelSet model); + + /// + /// Unloads the model. + /// + /// The model. + /// + Task UnloadModelAsync(UpscaleModelSet model); + + /// + /// Determines whether [is model loaded] [the specified model options]. + /// + /// The model options. + /// + /// true if [is model loaded] [the specified model options]; otherwise, false. + /// + bool IsModelLoaded(UpscaleModelSet modelOptions); + + /// + /// Generates the upscaled image. + /// + /// The model options. + /// The input image. + /// + Task GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); + + + /// + /// Generates the upscaled video. + /// + /// The model options. + /// The video input. 
+ /// + Task> GenerateAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken = default); + } +} diff --git a/OnnxStack.UI/Services/UpscaleService.cs b/OnnxStack.UI/Services/UpscaleService.cs new file mode 100644 index 00000000..419de2d2 --- /dev/null +++ b/OnnxStack.UI/Services/UpscaleService.cs @@ -0,0 +1,154 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Image; +using OnnxStack.Core.Services; +using OnnxStack.Core.Video; +using OnnxStack.FeatureExtractor.Pipelines; +using OnnxStack.ImageUpscaler.Common; +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.UI.Services +{ + public class UpscaleService : IUpscaleService + { + private readonly IVideoService _videoService; + private readonly ILogger _logger; + private readonly Dictionary _pipelines; + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The model service. + /// The image service. + public UpscaleService(IVideoService videoService) + { + _videoService = videoService; + _pipelines = new Dictionary(); + } + + + /// + /// Loads the model. + /// + /// The model. + /// + public async Task LoadModelAsync(UpscaleModelSet model) + { + if (_pipelines.ContainsKey(model)) + return true; + + var pipeline = CreatePipeline(model); + await pipeline.LoadAsync(); + return _pipelines.TryAdd(model, pipeline); + } + + + /// + /// Unloads the model. + /// + /// The model. + /// + public async Task UnloadModelAsync(UpscaleModelSet model) + { + if (_pipelines.Remove(model, out var pipeline)) + { + await pipeline?.UnloadAsync(); + } + return true; + } + + + /// + /// Determines whether [is model loaded] [the specified model options]. + /// + /// The model options. + /// + /// true if [is model loaded] [the specified model options]; otherwise, false. + /// + /// + public bool IsModelLoaded(UpscaleModelSet modelOptions) + { + return _pipelines.ContainsKey(modelOptions); + } + + + /// + /// Generates the upscaled image. + /// + /// The model options. + /// The input image. + /// + public async Task GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default) + { + return new OnnxImage(await GenerateInternalAsync(modelOptions, inputImage, cancellationToken), ImageNormalizeType.ZeroToOne); + } + + + + + /// + /// Generates the upscaled video. + /// + /// The model options. + /// The video input. + /// + public async Task> GenerateAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken = default) + { + var videoInfo = await _videoService.GetVideoInfoAsync(videoInput); + var tensorFrames = await GenerateInternalAsync(modelOptions, videoInput, videoInfo, cancellationToken); + + DenseTensor videoResult = default; + foreach (var tensorFrame in tensorFrames) + { + cancellationToken.ThrowIfCancellationRequested(); + videoResult = videoResult.Concatenate(tensorFrame); + } + return videoResult; + } + + + /// + /// Generates an upscaled image of the source provided. + /// + /// The model options. + /// The input image. 
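Taken together, the new UI-side service reduces upscaling to image-in/image-out. A usage sketch, assuming the service is resolved from DI and the file paths are placeholders:

```csharp
using System.Threading.Tasks;
using OnnxStack.Core.Image;
using OnnxStack.ImageUpscaler.Common;
using OnnxStack.UI.Services;

public static class UpscaleUsageSketch
{
    public static async Task RunAsync(IUpscaleService upscaleService, UpscaleModelSet modelSet)
    {
        // Load once; the service caches one ImageUpscalePipeline per UpscaleModelSet
        await upscaleService.LoadModelAsync(modelSet);

        // Image in, OnnxImage out - no tensor or byte[] plumbing at the call site
        var input = await OnnxImage.FromFileAsync("Input.png");
        var result = await upscaleService.GenerateAsync(modelSet, input);
        await result.SaveAsync("Upscaled.png");

        // Release the pipeline when finished
        await upscaleService.UnloadModelAsync(modelSet);
    }
}
```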
+ private async Task> GenerateInternalAsync(UpscaleModelSet modelSet, OnnxImage inputImage, CancellationToken cancellationToken) + { + if (!_pipelines.TryGetValue(modelSet, out var pipeline)) + throw new Exception("Pipeline not found or is unsupported"); + + return await pipeline.RunAsync(inputImage, cancellationToken); + } + + + /// + /// Generates the upscaled video. + /// + /// The model options. + /// The video input. + /// + public async Task>> GenerateInternalAsync(UpscaleModelSet modelSet, VideoInput videoInput, VideoInfo videoInfo, CancellationToken cancellationToken) + { + if (!_pipelines.TryGetValue(modelSet, out var pipeline)) + throw new Exception("Pipeline not found or is unsupported"); + + return new List>(); + } + + + + + + private ImageUpscalePipeline CreatePipeline(UpscaleModelSet modelSet) + { + return ImageUpscalePipeline.CreatePipeline(modelSet, _logger); + } + + } +} \ No newline at end of file diff --git a/OnnxStack.UI/UserControls/UpscalePickerControl.xaml.cs b/OnnxStack.UI/UserControls/UpscalePickerControl.xaml.cs index da1d2b6f..37749646 100644 --- a/OnnxStack.UI/UserControls/UpscalePickerControl.xaml.cs +++ b/OnnxStack.UI/UserControls/UpscalePickerControl.xaml.cs @@ -1,8 +1,8 @@ using Microsoft.Extensions.Logging; using OnnxStack.Core; -using OnnxStack.ImageUpscaler.Services; using OnnxStack.UI.Commands; using OnnxStack.UI.Models; +using OnnxStack.UI.Services; using System; using System.ComponentModel; using System.Linq; diff --git a/OnnxStack.UI/Views/UpscaleView.xaml.cs b/OnnxStack.UI/Views/UpscaleView.xaml.cs index 78250634..cbbd461a 100644 --- a/OnnxStack.UI/Views/UpscaleView.xaml.cs +++ b/OnnxStack.UI/Views/UpscaleView.xaml.cs @@ -1,8 +1,8 @@ using Microsoft.Extensions.Logging; using OnnxStack.Core.Image; -using OnnxStack.ImageUpscaler.Services; using OnnxStack.UI.Commands; using OnnxStack.UI.Models; +using OnnxStack.UI.Services; using System; using System.Collections.ObjectModel; using System.ComponentModel; @@ -199,11 +199,11 @@ private async Task Generate() try { var timestamp = Stopwatch.GetTimestamp(); - var resultBytes = await _upscaleService.GenerateAsByteAsync(SelectedModel.ModelSet, new OnnxImage(InputImage.GetImageBytes()), _cancelationTokenSource.Token); + var resultBytes = await _upscaleService.GenerateAsync(SelectedModel.ModelSet, new OnnxImage(InputImage.GetImageBytes()), _cancelationTokenSource.Token); if (resultBytes != null) { var elapsed = Stopwatch.GetElapsedTime(timestamp).TotalSeconds; - var imageResult = new UpscaleResult(Utils.CreateBitmap(resultBytes), UpscaleInfo with { }, elapsed); + var imageResult = new UpscaleResult(Utils.CreateBitmap(resultBytes.GetImageBytes()), UpscaleInfo with { }, elapsed); ResultImage = imageResult; HasResult = true; diff --git a/OnnxStack.UI/appsettings.json b/OnnxStack.UI/appsettings.json index a5a7a5d5..90d2d4de 100644 --- a/OnnxStack.UI/appsettings.json +++ b/OnnxStack.UI/appsettings.json @@ -18,15 +18,15 @@ "ModelSet": { "Name": "RealSR BSRGAN x4", "IsEnabled": true, - "Channels": 3, - "SampleSize": 512, - "ScaleFactor": 4, "DeviceId": 0, "InterOpNumThreads": 0, "IntraOpNumThreads": 0, "ExecutionMode": "ORT_SEQUENTIAL", "ExecutionProvider": "DirectML", "UpscaleModelConfig": { + "Channels": 3, + "ScaleFactor": 4, + "SampleSize": 512, "OnnxModelPath": "D:\\Repositories\\upscaler\\SwinIR\\003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx" } } From 13e90c2429478a104d91e2d3a5c30a7bbfb2ad1d Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Tue, 13 Feb 2024 20:27:07 +1300 Subject: [PATCH 3/5] Wrap and simplify 
Video input --- .../Examples/FeatureExtractorExample.cs | 4 +- .../Examples/StableDiffusionBatch.cs | 2 +- .../Examples/StableDiffusionExample.cs | 7 +- .../Examples/VideoToVideoExample.cs | 16 +- OnnxStack.Core/Config/OnnxModelSetConfig.cs | 16 - OnnxStack.Core/Constants.cs | 18 - OnnxStack.Core/Extensions/TensorExtension.cs | 21 + OnnxStack.Core/Image/OnnxImage.cs | 6 +- .../Model/OnnxInferenceParameters.cs | 2 +- OnnxStack.Core/Model/OnnxMetadata.cs | 2 +- OnnxStack.Core/Model/OnnxNamedMetadata.cs | 2 +- OnnxStack.Core/Model/OnnxValueCollection.cs | 2 +- OnnxStack.Core/Registration.cs | 1 - OnnxStack.Core/Services/IVideoService.cs | 109 ----- OnnxStack.Core/Services/VideoService.cs | 460 ------------------ OnnxStack.Core/Video/Extensions.cs | 49 -- OnnxStack.Core/Video/OnnxVideo.cs | 161 ++++++ OnnxStack.Core/Video/VideoFrame.cs | 9 - OnnxStack.Core/Video/VideoFrames.cs | 6 - OnnxStack.Core/Video/VideoHelper.cs | 310 ++++++++++++ OnnxStack.Core/Video/VideoInfo.cs | 11 +- OnnxStack.Core/Video/VideoInput.cs | 118 ----- OnnxStack.Core/Video/VideoOutput.cs | 4 - .../Pipelines/FeatureExtractorPipeline.cs | 13 +- OnnxStack.FeatureExtractor/README.md | 25 +- .../Common/BatchResult.cs | 4 +- .../Config/PromptOptions.cs | 6 +- .../Pipelines/Base/IPipeline.cs | 52 ++ .../Pipelines/Base/PipelineBase.cs | 71 ++- .../Pipelines/StableDiffusionPipeline.cs | 223 ++++++++- OnnxStack.UI/App.xaml.cs | 4 - OnnxStack.UI/Services/IUpscaleService.cs | 2 +- .../UserControls/SchedulerControl.xaml.cs | 3 +- 33 files changed, 872 insertions(+), 867 deletions(-) delete mode 100644 OnnxStack.Core/Config/OnnxModelSetConfig.cs delete mode 100644 OnnxStack.Core/Constants.cs delete mode 100644 OnnxStack.Core/Services/IVideoService.cs delete mode 100644 OnnxStack.Core/Services/VideoService.cs delete mode 100644 OnnxStack.Core/Video/Extensions.cs create mode 100644 OnnxStack.Core/Video/OnnxVideo.cs delete mode 100644 OnnxStack.Core/Video/VideoFrame.cs delete mode 100644 OnnxStack.Core/Video/VideoFrames.cs create mode 100644 OnnxStack.Core/Video/VideoHelper.cs delete mode 100644 OnnxStack.Core/Video/VideoInput.cs delete mode 100644 OnnxStack.Core/Video/VideoOutput.cs diff --git a/OnnxStack.Console/Examples/FeatureExtractorExample.cs b/OnnxStack.Console/Examples/FeatureExtractorExample.cs index 3cbc14cc..c45e7b6d 100644 --- a/OnnxStack.Console/Examples/FeatureExtractorExample.cs +++ b/OnnxStack.Console/Examples/FeatureExtractorExample.cs @@ -1,4 +1,5 @@ using OnnxStack.Core.Image; +using OnnxStack.Core.Video; using OnnxStack.FeatureExtractor.Pipelines; using OnnxStack.StableDiffusion.Config; using SixLabors.ImageSharp; @@ -47,7 +48,7 @@ public async Task RunAsync() var timestamp = Stopwatch.GetTimestamp(); OutputHelpers.WriteConsole($"Load pipeline`{pipeline.Name}`", ConsoleColor.Cyan); - // Run Pipeline + // Run Image Pipeline var imageFeature = await pipeline.RunAsync(inputImage); OutputHelpers.WriteConsole($"Generating image", ConsoleColor.Cyan); @@ -55,7 +56,6 @@ public async Task RunAsync() // Save Image await imageFeature.SaveAsync(Path.Combine(_outputDirectory, $"{pipeline.Name}.png")); - OutputHelpers.WriteConsole($"Unload pipeline", ConsoleColor.Cyan); //Unload diff --git a/OnnxStack.Console/Examples/StableDiffusionBatch.cs b/OnnxStack.Console/Examples/StableDiffusionBatch.cs index 6b48dffe..8259b0d4 100644 --- a/OnnxStack.Console/Examples/StableDiffusionBatch.cs +++ b/OnnxStack.Console/Examples/StableDiffusionBatch.cs @@ -61,7 +61,7 @@ public async Task RunAsync() await foreach (var result in 
pipeline.RunBatchAsync(batchOptions, promptOptions, progressCallback: OutputHelpers.BatchProgressCallback)) { // Create Image from Tensor result - var image = result.ImageResult; + var image = new OnnxImage(result.Result); // Save Image File var outputFilename = Path.Combine(_outputDirectory, $"{modelSet.Name}_{result.SchedulerOptions.Seed}.png"); diff --git a/OnnxStack.Console/Examples/StableDiffusionExample.cs b/OnnxStack.Console/Examples/StableDiffusionExample.cs index 09564d50..ac3ea81d 100644 --- a/OnnxStack.Console/Examples/StableDiffusionExample.cs +++ b/OnnxStack.Console/Examples/StableDiffusionExample.cs @@ -67,14 +67,11 @@ public async Task RunAsync() OutputHelpers.WriteConsole($"Generating '{schedulerType}' Image...", ConsoleColor.Green); // Run pipeline - var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback); - - // Create Image from Tensor result - var image = new OnnxImage(result); + var result = await pipeline.GenerateImageAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback); // Save Image File var outputFilename = Path.Combine(_outputDirectory, $"{modelSet.Name}_{schedulerOptions.SchedulerType}.png"); - await image.SaveAsync(outputFilename); + await result.SaveAsync(outputFilename); OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}, Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Green); } diff --git a/OnnxStack.Console/Examples/VideoToVideoExample.cs b/OnnxStack.Console/Examples/VideoToVideoExample.cs index f228818c..8b87d55b 100644 --- a/OnnxStack.Console/Examples/VideoToVideoExample.cs +++ b/OnnxStack.Console/Examples/VideoToVideoExample.cs @@ -1,6 +1,4 @@ -using OnnxStack.Core.Services; -using OnnxStack.Core.Video; -using OnnxStack.StableDiffusion.Common; +using OnnxStack.Core.Video; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Pipelines; @@ -11,12 +9,10 @@ public sealed class VideoToVideoExample : IExampleRunner { private readonly string _outputDirectory; private readonly StableDiffusionConfig _configuration; - private readonly IVideoService _videoService; - public VideoToVideoExample(StableDiffusionConfig configuration, IVideoService videoService) + public VideoToVideoExample(StableDiffusionConfig configuration) { _configuration = configuration; - _videoService = videoService; _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(VideoToVideoExample)); Directory.CreateDirectory(_outputDirectory); } @@ -30,8 +26,7 @@ public VideoToVideoExample(StableDiffusionConfig configuration, IVideoService vi public async Task RunAsync() { // Load Video - var targetFPS = 15; - var videoInput = await VideoInput.FromFileAsync("C:\\Users\\Deven\\Pictures\\gidsgphy.gif", targetFPS); + var videoInput = await OnnxVideo.FromFileAsync("C:\\Users\\Deven\\Pictures\\gidsgphy.gif"); // Loop though the appsettings.json model sets foreach (var modelSet in _configuration.ModelSets) @@ -53,11 +48,10 @@ public async Task RunAsync() }; // Run pipeline - var result = await pipeline.RunAsync(promptOptions, progressCallback: OutputHelpers.FrameProgressCallback); + var result = await pipeline.GenerateVideoAsync(promptOptions, progressCallback: OutputHelpers.FrameProgressCallback); // Save Video File - var outputFilename = Path.Combine(_outputDirectory, $"{modelSet.Name}.mp4"); - await VideoInput.SaveFileAsync(result, outputFilename, targetFPS); + await 
result.SaveAsync(Path.Combine(_outputDirectory, $"Result.mp4")); } } } diff --git a/OnnxStack.Core/Config/OnnxModelSetConfig.cs b/OnnxStack.Core/Config/OnnxModelSetConfig.cs deleted file mode 100644 index c9e9432a..00000000 --- a/OnnxStack.Core/Config/OnnxModelSetConfig.cs +++ /dev/null @@ -1,16 +0,0 @@ -using Microsoft.ML.OnnxRuntime; - -namespace OnnxStack.Core.Config -{ - public class OnnxModelSetConfig : IOnnxModelSetConfig - { - public string Name { get; set; } - public bool IsEnabled { get; set; } - public int DeviceId { get; set; } - public string OnnxModelPath { get; set; } - public int InterOpNumThreads { get; set; } - public int IntraOpNumThreads { get; set; } - public ExecutionMode ExecutionMode { get; set; } - public ExecutionProvider ExecutionProvider { get; set; } - } -} diff --git a/OnnxStack.Core/Constants.cs b/OnnxStack.Core/Constants.cs deleted file mode 100644 index b712fcdb..00000000 --- a/OnnxStack.Core/Constants.cs +++ /dev/null @@ -1,18 +0,0 @@ -using System.Collections.Generic; - -namespace OnnxStack.Core -{ - public static class Constants - { - /// - /// The width/height valid sizes - /// - public static readonly IReadOnlyList ValidSizes; - - static Constants() - { - // Cache an array with enough blank tokens to fill an empty prompt - ValidSizes = new List { 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024 }; - } - } -} diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs index 45441c4a..a6c6456d 100644 --- a/OnnxStack.Core/Extensions/TensorExtension.cs +++ b/OnnxStack.Core/Extensions/TensorExtension.cs @@ -1,5 +1,6 @@ using Microsoft.ML.OnnxRuntime.Tensors; using System; +using System.Collections.Generic; namespace OnnxStack.Core { @@ -54,6 +55,26 @@ public static void NormalizeMinMax(this DenseTensor tensor) } + /// + /// Splits the tensor across the batch dimension. + /// + /// The tensor. + /// + public static IEnumerable> SplitBatch(this DenseTensor tensor) + { + var count = tensor.Dimensions[0]; + var dimensions = tensor.Dimensions.ToArray(); + dimensions[0] = 1; + + var newLength = (int)tensor.Length / count; + for (int i = 0; i < count; i++) + { + var start = i * newLength; + yield return new DenseTensor(tensor.Buffer.Slice(start, newLength), dimensions); + } + } + + /// /// Concatenates the specified tensors along the specified axis. /// diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs index 45f5cd3a..d4d837bb 100644 --- a/OnnxStack.Core/Image/OnnxImage.cs +++ b/OnnxStack.Core/Image/OnnxImage.cs @@ -10,7 +10,7 @@ namespace OnnxStack.Core.Image { - public class OnnxImage : IDisposable + public sealed class OnnxImage : IDisposable { private readonly Image _imageData; @@ -222,6 +222,10 @@ public void Resize(int height, int width, ResizeMode resizeMode = ResizeMode.Cro }); } + public OnnxImage Clone() + { + return new OnnxImage(_imageData); + } /// /// Saves the specified image to file. 
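The new `SplitBatch` extension is what lets batched tensor output be handled per frame; the `OnnxVideo` tensor constructor introduced later in this patch uses exactly this pattern. A minimal sketch:

```csharp
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Image;

public static class BatchTensorSketch
{
    // Slice a [N, C, H, W] tensor along the batch dimension into N images.
    public static List<OnnxImage> ToImages(DenseTensor<float> batchTensor)
    {
        return batchTensor
            .SplitBatch()                       // yields N tensors of shape [1, C, H, W]
            .Select(frame => new OnnxImage(frame))
            .ToList();
    }
}
```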
diff --git a/OnnxStack.Core/Model/OnnxInferenceParameters.cs b/OnnxStack.Core/Model/OnnxInferenceParameters.cs index ad813702..fbdb7389 100644 --- a/OnnxStack.Core/Model/OnnxInferenceParameters.cs +++ b/OnnxStack.Core/Model/OnnxInferenceParameters.cs @@ -5,7 +5,7 @@ namespace OnnxStack.Core.Model { - public class OnnxInferenceParameters : IDisposable + public sealed class OnnxInferenceParameters : IDisposable { private readonly RunOptions _runOptions; private readonly OnnxMetadata _metadata; diff --git a/OnnxStack.Core/Model/OnnxMetadata.cs b/OnnxStack.Core/Model/OnnxMetadata.cs index 278062f4..dc0c4f42 100644 --- a/OnnxStack.Core/Model/OnnxMetadata.cs +++ b/OnnxStack.Core/Model/OnnxMetadata.cs @@ -2,7 +2,7 @@ namespace OnnxStack.Core.Model { - public record OnnxMetadata + public sealed record OnnxMetadata { /// /// Gets or sets the inputs. diff --git a/OnnxStack.Core/Model/OnnxNamedMetadata.cs b/OnnxStack.Core/Model/OnnxNamedMetadata.cs index c0207bbf..2c6ed585 100644 --- a/OnnxStack.Core/Model/OnnxNamedMetadata.cs +++ b/OnnxStack.Core/Model/OnnxNamedMetadata.cs @@ -3,7 +3,7 @@ namespace OnnxStack.Core.Model { - public record OnnxNamedMetadata(string Name, NodeMetadata Value) + public sealed record OnnxNamedMetadata(string Name, NodeMetadata Value) { internal static OnnxNamedMetadata Create(KeyValuePair metadata) { diff --git a/OnnxStack.Core/Model/OnnxValueCollection.cs b/OnnxStack.Core/Model/OnnxValueCollection.cs index d0b3fd8f..c651425d 100644 --- a/OnnxStack.Core/Model/OnnxValueCollection.cs +++ b/OnnxStack.Core/Model/OnnxValueCollection.cs @@ -4,7 +4,7 @@ namespace OnnxStack.Core.Model { - public class OnnxValueCollection : IDisposable + public sealed class OnnxValueCollection : IDisposable { private readonly List _metaData; private readonly Dictionary _values; diff --git a/OnnxStack.Core/Registration.cs b/OnnxStack.Core/Registration.cs index df89d72a..8a027ab6 100644 --- a/OnnxStack.Core/Registration.cs +++ b/OnnxStack.Core/Registration.cs @@ -1,7 +1,6 @@ using Microsoft.Extensions.DependencyInjection; using OnnxStack.Common.Config; using OnnxStack.Core.Config; -using OnnxStack.Core.Services; namespace OnnxStack.Core { diff --git a/OnnxStack.Core/Services/IVideoService.cs b/OnnxStack.Core/Services/IVideoService.cs deleted file mode 100644 index d1e3b79e..00000000 --- a/OnnxStack.Core/Services/IVideoService.cs +++ /dev/null @@ -1,109 +0,0 @@ -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core.Video; -using System.Collections.Generic; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace OnnxStack.Core.Services -{ - /// - /// Service with basic handling of video for use in OnnxStack, Frame->Video and Video->Frames - /// - public interface IVideoService - { - /// - /// Gets the video information asynchronous. - /// - /// The video bytes. - /// The cancellation token. - /// - Task GetVideoInfoAsync(byte[] videoBytes, CancellationToken cancellationToken = default); - - /// - /// Gets the video information asynchronous. - /// - /// The video stream. - /// The cancellation token. - /// - Task GetVideoInfoAsync(Stream videoStream, CancellationToken cancellationToken = default); - - /// - /// Gets the video information, Size, FPS, Duration etc. - /// - /// The video input. - /// The cancellation token. - /// - /// No video data found - Task GetVideoInfoAsync(VideoInput videoInput, CancellationToken cancellationToken = default); - - - /// - /// Creates a collection of PNG frames from a video source - /// - /// The video bytes. - /// The video FPS. 
- /// The cancellation token. - /// - Task CreateFramesAsync(byte[] videoBytes, float? videoFPS = default, CancellationToken cancellationToken = default); - - - /// - /// Creates a collection of PNG frames from a video source - /// - /// The video stream. - /// The video FPS. - /// The cancellation token. - /// - Task CreateFramesAsync(Stream videoStream, float? videoFPS = default, CancellationToken cancellationToken = default); - - - /// - /// Creates a collection of PNG frames from a video source - /// - /// The video input. - /// The video FPS. - /// The cancellation token. - /// - /// VideoTensor not supported - /// No video data found - Task CreateFramesAsync(VideoInput videoInput, float? videoFPS = default, CancellationToken cancellationToken = default); - - - /// - /// Creates and MP4 video from a collection of PNG images. - /// - /// The video frames. - /// The video FPS. - /// The cancellation token. - /// - Task CreateVideoAsync(IEnumerable videoFrames, float videoFPS, CancellationToken cancellationToken = default); - - - /// - /// Creates and MP4 video from a collection of PNG images. - /// - /// The video frames. - /// The cancellation token. - /// - Task CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default); - - // - /// Creates and MP4 video from a collection of PNG images. - /// - /// The video frames. - /// The video FPS. - /// The cancellation token. - /// - Task CreateVideoAsync(DenseTensor videoTensor, float videoFPS, CancellationToken cancellationToken = default); - - /// - /// Streams frames as PNG as they are processed from a video source - /// - /// The video bytes. - /// The target FPS. - /// The cancellation token. - /// - IAsyncEnumerable StreamFramesAsync(byte[] videoBytes, float targetFPS, CancellationToken cancellationToken = default); - } -} \ No newline at end of file diff --git a/OnnxStack.Core/Services/VideoService.cs b/OnnxStack.Core/Services/VideoService.cs deleted file mode 100644 index 3779c2ed..00000000 --- a/OnnxStack.Core/Services/VideoService.cs +++ /dev/null @@ -1,460 +0,0 @@ -using FFMpegCore; -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core.Config; -using OnnxStack.Core.Video; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; - -namespace OnnxStack.Core.Services -{ - /// - /// Service with basic handling of video for use in OnnxStack, Frame->Video and Video->Frames - /// - public class VideoService : IVideoService - { - private readonly OnnxStackConfig _configuration; - - /// - /// Initializes a new instance of the class. - /// - /// The configuration. - public VideoService(OnnxStackConfig configuration) - { - _configuration = configuration; - } - - #region Public Members - - /// - /// Gets the video information, Size, FPS, Duration etc. - /// - /// The video input. - /// The cancellation token. 
- /// - /// VideoTensor not supported - /// No video data found - public async Task GetVideoInfoAsync(VideoInput videoInput, CancellationToken cancellationToken = default) - { - if (videoInput.VideoBytes is not null) - return await GetVideoInfoAsync(videoInput.VideoBytes, cancellationToken); - if (videoInput.VideoStream is not null) - return await GetVideoInfoAsync(videoInput.VideoStream, cancellationToken); - if (videoInput.VideoTensor is not null) - throw new NotSupportedException("VideoTensor not supported"); - - throw new ArgumentException("No video data found"); - } - - - /// - /// Gets the video information asynchronous. - /// - /// The video stream. - /// The cancellation token. - /// - public async Task GetVideoInfoAsync(Stream videoStream, CancellationToken cancellationToken = default) - { - using (var memoryStream = new MemoryStream()) - { - await memoryStream.CopyToAsync(videoStream, cancellationToken); - return await GetVideoInfoInternalAsync(memoryStream, cancellationToken); - } - } - - - /// - /// Gets the video information asynchronous. - /// - /// The video bytes. - /// The cancellation token. - /// - public async Task GetVideoInfoAsync(byte[] videoBytes, CancellationToken cancellationToken = default) - { - using (var videoStream = new MemoryStream(videoBytes)) - { - return await GetVideoInfoInternalAsync(videoStream, cancellationToken); - } - } - - - /// - /// Creates and MP4 video from a collection of PNG images. - /// - /// The video frames. - /// The cancellation token. - /// - public async Task CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default) - { - return await CreateVideoInternalAsync(videoFrames.Frames, videoFrames.Info.FPS, cancellationToken); - } - - - /// - /// Creates and MP4 video from a collection of PNG images. - /// - /// The video tensor. - /// The video FPS. - /// The cancellation token. - /// - public async Task CreateVideoAsync(DenseTensor videoTensor, float videoFPS, CancellationToken cancellationToken = default) - { - var videoFrames = await videoTensor - .ToVideoFramesAsBytesAsync() - .Select(x => new VideoFrame(x)) - .ToListAsync(cancellationToken); - return await CreateVideoInternalAsync(videoFrames, videoFPS, cancellationToken); - } - - - /// - /// Creates and MP4 video from a collection of PNG images. - /// - /// The video frames. - /// The video FPS. - /// The cancellation token. - /// - public async Task CreateVideoAsync(IEnumerable videoFrames, float videoFPS, CancellationToken cancellationToken = default) - { - var frames = videoFrames.Select(x => new VideoFrame(x)); - return await CreateVideoInternalAsync(frames, videoFPS, cancellationToken); - } - - - /// - /// Creates a collection of PNG frames from a video source - /// - /// The video input. - /// The video FPS. - /// The cancellation token. - /// - /// VideoTensor not supported - /// No video data found - public async Task CreateFramesAsync(VideoInput videoInput, float? 
videoFPS = default, CancellationToken cancellationToken = default) - { - - if (videoInput.VideoBytes is not null) - return await CreateFramesAsync(videoInput.VideoBytes, videoFPS, cancellationToken); - if (videoInput.VideoStream is not null) - return await CreateFramesAsync(videoInput.VideoStream, videoFPS, cancellationToken); - if (videoInput.VideoFrames is not null) - return videoInput.VideoFrames; - if (videoInput.VideoTensor is not null) - throw new NotSupportedException("VideoTensor not supported"); - - throw new ArgumentException("No video data found"); - } - - - /// - /// Creates a collection of PNG frames from a video source - /// - /// The video bytes. - /// The video FPS. - /// The cancellation token. - /// - public async Task CreateFramesAsync(byte[] videoBytes, float? videoFPS = default, CancellationToken cancellationToken = default) - { - var videoInfo = await GetVideoInfoAsync(videoBytes, cancellationToken); - var targetFPS = videoFPS ?? videoInfo.FPS; - var videoFrames = await CreateFramesInternalAsync(videoBytes, targetFPS, cancellationToken).ToListAsync(cancellationToken); - videoInfo = videoInfo with { FPS = targetFPS }; - return new VideoFrames(videoInfo, videoFrames); - } - - - /// - /// Creates a collection of PNG frames from a video source - /// - /// The video stream. - /// The video FPS. - /// The cancellation token. - /// - public async Task CreateFramesAsync(Stream videoStream, float? videoFPS = default, CancellationToken cancellationToken = default) - { - using (var memoryStream = new MemoryStream()) - { - await memoryStream.CopyToAsync(videoStream, cancellationToken).ConfigureAwait(false); - var videoBytes = memoryStream.ToArray(); - var videoInfo = await GetVideoInfoAsync(videoBytes, cancellationToken); - var targetFPS = videoFPS ?? videoInfo.FPS; - var videoFrames = await CreateFramesInternalAsync(videoBytes, targetFPS, cancellationToken).ToListAsync(cancellationToken); - videoInfo = videoInfo with { FPS = targetFPS }; - return new VideoFrames(videoInfo, videoFrames); - } - } - - - /// - /// Streams frames as PNG as they are processed from a video source - /// - /// The video bytes. - /// The target FPS. - /// The cancellation token. - /// - public IAsyncEnumerable StreamFramesAsync(byte[] videoBytes, float targetFPS, CancellationToken cancellationToken = default) - { - return CreateFramesInternalAsync(videoBytes, targetFPS, cancellationToken); - } - - #endregion - - #region Private Members - - - /// - /// Gets the video information. - /// - /// The video stream. - /// The cancellation token. - /// - private async Task GetVideoInfoInternalAsync(MemoryStream videoStream, CancellationToken cancellationToken = default) - { - var result = await FFProbe.AnalyseAsync(videoStream).ConfigureAwait(false); - return new VideoInfo(result.PrimaryVideoStream.Width, result.PrimaryVideoStream.Height, result.Duration, (int)result.PrimaryVideoStream.FrameRate); - } - - - /// - /// Creates an MP4 video from a collection of PNG frames - /// - /// The image data. - /// The FPS. - /// The cancellation token. 
- /// - private async Task CreateVideoInternalAsync(IEnumerable imageData, float fps = 15, CancellationToken cancellationToken = default) - { - string tempVideoPath = GetTempFilename(); - try - { - // Analyze first fram to get some details - var frameInfo = await GetVideoInfoAsync(imageData.First().Frame); - var aspectRatio = (double)frameInfo.Width / frameInfo.Height; - using (var videoWriter = CreateWriter(tempVideoPath, fps, aspectRatio)) - { - // Start FFMPEG - videoWriter.Start(); - foreach (var image in imageData) - { - // Write each frame to the input stream of FFMPEG - await videoWriter.StandardInput.BaseStream.WriteAsync(image.Frame, cancellationToken); - } - - // Done close stream and wait for app to process - videoWriter.StandardInput.BaseStream.Close(); - await videoWriter.WaitForExitAsync(cancellationToken); - - // Read result from temp file - var videoResult = await File.ReadAllBytesAsync(tempVideoPath, cancellationToken); - - // Analyze the result - var videoInfo = await GetVideoInfoAsync(videoResult); - return new VideoOutput(videoResult, videoInfo); - } - } - finally - { - DeleteTempFile(tempVideoPath); - } - } - - - /// - /// Creates a collection of PNG frames from a video source - /// - /// The video data. - /// The FPS. - /// The cancellation token. - /// - /// Invalid PNG header - private async IAsyncEnumerable CreateFramesInternalAsync(byte[] videoData, float fps = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - string tempVideoPath = GetTempFilename(); - try - { - await File.WriteAllBytesAsync(tempVideoPath, videoData, cancellationToken); - using (var ffmpegProcess = CreateReader(tempVideoPath, fps)) - { - // Start FFMPEG - ffmpegProcess.Start(); - - // FFMPEG output stream - var processOutputStream = ffmpegProcess.StandardOutput.BaseStream; - - // Buffer to hold the current image - var buffer = new byte[20480000]; - - var currentIndex = 0; - while (!cancellationToken.IsCancellationRequested) - { - // Reset the index new PNG - currentIndex = 0; - - // Read the PNG Header - if (await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 8), cancellationToken) <= 0) - break; - - currentIndex += 8;// header length - - if (!IsImageHeader(buffer)) - throw new Exception("Invalid PNG header"); - - // loop through each chunk - while (true) - { - // Read the chunk header - await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 12), cancellationToken); - - var chunkIndex = currentIndex; - currentIndex += 12; // Chunk header length - - // Get the chunk's content size in bytes from the header we just read - var totalSize = buffer[chunkIndex] << 24 | buffer[chunkIndex + 1] << 16 | buffer[chunkIndex + 2] << 8 | buffer[chunkIndex + 3]; - if (totalSize > 0) - { - var totalRead = 0; - while (totalRead < totalSize) - { - int read = await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, totalSize - totalRead), cancellationToken); - currentIndex += read; - totalRead += read; - } - continue; - } - - // If the size is 0 and is the end of the image - if (totalSize == 0 && IsImageEnd(buffer, chunkIndex)) - break; - } - - yield return new VideoFrame(buffer[..currentIndex]); - } - - if (cancellationToken.IsCancellationRequested) - ffmpegProcess.Kill(); - } - } - finally - { - DeleteTempFile(tempVideoPath); - } - } - - - /// - /// Gets the temporary filename. 
- /// - /// - private string GetTempFilename() - { - if (!Directory.Exists(_configuration.TempPath)) - Directory.CreateDirectory(_configuration.TempPath); - - return Path.Combine(_configuration.TempPath, $"{Path.GetFileNameWithoutExtension(Path.GetRandomFileName())}.mp4"); - } - - - /// - /// Deletes the temporary file. - /// - /// The filename. - private void DeleteTempFile(string filename) - { - try - { - if (File.Exists(filename)) - File.Delete(filename); - } - catch (Exception) - { - // File in use, Log - } - } - - - /// - /// Creates FFMPEG video reader process. - /// - /// The input file. - /// The FPS. - /// - private Process CreateReader(string inputFile, float fps) - { - var ffmpegProcess = new Process(); - ffmpegProcess.StartInfo.FileName = _configuration.FFmpegPath; - ffmpegProcess.StartInfo.Arguments = $"-hide_banner -loglevel error -i \"{inputFile}\" -c:v png -r {fps} -f image2pipe -"; - ffmpegProcess.StartInfo.RedirectStandardOutput = true; - ffmpegProcess.StartInfo.UseShellExecute = false; - ffmpegProcess.StartInfo.CreateNoWindow = true; - return ffmpegProcess; - } - - - /// - /// Creates FFMPEG video writer process. - /// - /// The output file. - /// The FPS. - /// The aspect ratio. - /// - private Process CreateWriter(string outputFile, float fps, double aspectRatio) - { - var ffmpegProcess = new Process(); - ffmpegProcess.StartInfo.FileName = _configuration.FFmpegPath; - ffmpegProcess.StartInfo.Arguments = $"-hide_banner -loglevel error -framerate {fps:F4} -i - -c:v libx264 -movflags +faststart -vf format=yuv420p -aspect {aspectRatio} {outputFile}"; - ffmpegProcess.StartInfo.RedirectStandardInput = true; - ffmpegProcess.StartInfo.UseShellExecute = false; - ffmpegProcess.StartInfo.CreateNoWindow = true; - return ffmpegProcess; - } - - - /// - /// Determines whether we are at the start of a PNG image in the specified buffer. - /// - /// The buffer. - /// The offset. - /// - /// true if the start of a PNG image sequence is detectedfalse. - /// - private static bool IsImageHeader(byte[] buffer) - { - // PNG Header http://www.libpng.org/pub/png/spec/1.2/PNG-Structure.html#PNG-file-signature - if (buffer[0] != 0x89 - || buffer[1] != 0x50 - || buffer[2] != 0x4E - || buffer[3] != 0x47 - || buffer[4] != 0x0D - || buffer[5] != 0x0A - || buffer[6] != 0x1A - || buffer[7] != 0x0A) - return false; - - return true; - } - - - /// - /// Determines whether we are at the end of a PNG image in the specified buffer. - /// - /// The buffer. - /// The offset. - /// - /// true if the end of a PNG image sequence is detectedfalse. 
- /// - private static bool IsImageEnd(byte[] buffer, int offset) - { - return buffer[offset + 4] == 0x49 // I - && buffer[offset + 5] == 0x45 // E - && buffer[offset + 6] == 0x4E // N - && buffer[offset + 7] == 0x44; // D - } - } - - #endregion -} diff --git a/OnnxStack.Core/Video/Extensions.cs b/OnnxStack.Core/Video/Extensions.cs deleted file mode 100644 index dbabfa88..00000000 --- a/OnnxStack.Core/Video/Extensions.cs +++ /dev/null @@ -1,49 +0,0 @@ -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core.Image; -using SixLabors.ImageSharp; -using SixLabors.ImageSharp.PixelFormats; -using System.Collections.Generic; - -namespace OnnxStack.Core.Video -{ - public static class Extensions - { - public static IEnumerable> ToVideoFrames(this DenseTensor videoTensor) - { - var count = videoTensor.Dimensions[0]; - var dimensions = videoTensor.Dimensions.ToArray(); - dimensions[0] = 1; - - var newLength = (int)videoTensor.Length / count; - for (int i = 0; i < count; i++) - { - var start = i * newLength; - yield return new DenseTensor(videoTensor.Buffer.Slice(start, newLength), dimensions); - } - } - - public static IEnumerable ToVideoFramesAsBytes(this DenseTensor videoTensor) - { - foreach (var frame in videoTensor.ToVideoFrames()) - { - yield return new OnnxImage(frame).GetImageBytes(); - } - } - - public static async IAsyncEnumerable ToVideoFramesAsBytesAsync(this DenseTensor videoTensor) - { - foreach (var frame in videoTensor.ToVideoFrames()) - { - yield return new OnnxImage(frame).GetImageBytes(); - } - } - - //public static IEnumerable> ToVideoFramesAsImage(this DenseTensor videoTensor) - //{ - // foreach (var frame in videoTensor.ToVideoFrames()) - // { - // yield return frame.ToImage(); - // } - //} - } -} diff --git a/OnnxStack.Core/Video/OnnxVideo.cs b/OnnxStack.Core/Video/OnnxVideo.cs new file mode 100644 index 00000000..a61f4fc9 --- /dev/null +++ b/OnnxStack.Core/Video/OnnxVideo.cs @@ -0,0 +1,161 @@ +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core.Image; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.Core.Video +{ + public sealed class OnnxVideo : IDisposable + { + private readonly VideoInfo _info; + private readonly IReadOnlyList _frames; + + /// + /// Initializes a new instance of the class. + /// + /// The information. + /// The frames. + public OnnxVideo(VideoInfo info, List frames) + { + _info = info; + _frames = frames; + } + + + /// + /// Initializes a new instance of the class. + /// + /// The information. + /// The video tensor. + public OnnxVideo(VideoInfo info, DenseTensor videoTensor) + { + _info = info; + _frames = videoTensor + .SplitBatch() + .Select(x => new OnnxImage(x)) + .ToList(); + } + + + /// + /// Gets the height. + /// + public int Height => _info.Height; + + /// + /// Gets the width. + /// + public int Width => _info.Width; + + /// + /// Gets the frame rate. + /// + public float FrameRate => _info.FrameRate; + + /// + /// Gets the duration. + /// + public TimeSpan Duration => _info.Duration; + + /// + /// Gets the information. + /// + public VideoInfo Info => _info; + + /// + /// Gets the frames. + /// + public IReadOnlyList Frames => _frames; + + /// + /// Gets the aspect ratio. + /// + public double AspectRatio => (double)_info.Width / _info.Height; + + /// + /// Gets a value indicating whether this instance has video. + /// + /// + /// true if this instance has video; otherwise, false. 
+ /// + public bool HasVideo + { + get { return !_frames.IsNullOrEmpty(); } + } + + + /// + /// Gets the frame at the specified index. + /// + /// The index. + /// + public OnnxImage GetFrame(int index) + { + if (_frames?.Count > index) + return _frames[index]; + + return null; + } + + + /// + /// Resizes the video. + /// + /// The height. + /// The width. + public void Resize(int height, int width) + { + foreach (var frame in _frames) + frame.Resize(height, width); + + _info.Width = width; + _info.Height = height; + } + + + /// + /// Saves the video to file. + /// + /// The filename. + /// The cancellation token. + /// + public Task SaveAsync(string filename, CancellationToken cancellationToken = default) + { + return VideoHelper.WriteVideoFramesAsync(this, filename, cancellationToken); + } + + + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// + public void Dispose() + { + foreach (var item in _frames) + { + item?.Dispose(); + } + } + + /// + /// Load a video from file + /// + /// The filename. + /// The frame rate. + /// The cancellation token. + /// + public static async Task FromFileAsync(string filename, float? frameRate = default, CancellationToken cancellationToken = default) + { + var videoBytes = await File.ReadAllBytesAsync(filename, cancellationToken); + var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoBytes); + if (frameRate.HasValue) + videoInfo = videoInfo with { FrameRate = Math.Min(videoInfo.FrameRate, frameRate.Value) }; + + var videoFrames = await VideoHelper.ReadVideoFramesAsync(videoBytes, videoInfo.FrameRate, cancellationToken); + return new OnnxVideo(videoInfo, videoFrames); + } + } +} diff --git a/OnnxStack.Core/Video/VideoFrame.cs b/OnnxStack.Core/Video/VideoFrame.cs deleted file mode 100644 index 835f8d10..00000000 --- a/OnnxStack.Core/Video/VideoFrame.cs +++ /dev/null @@ -1,9 +0,0 @@ -using OnnxStack.Core.Image; - -namespace OnnxStack.Core.Video -{ - public record VideoFrame(byte[] Frame) - { - public OnnxImage ExtraFrame { get; set; } - } -} diff --git a/OnnxStack.Core/Video/VideoFrames.cs b/OnnxStack.Core/Video/VideoFrames.cs deleted file mode 100644 index 246b7aca..00000000 --- a/OnnxStack.Core/Video/VideoFrames.cs +++ /dev/null @@ -1,6 +0,0 @@ -using System.Collections.Generic; - -namespace OnnxStack.Core.Video -{ - public record VideoFrames(VideoInfo Info, IReadOnlyList Frames); -} diff --git a/OnnxStack.Core/Video/VideoHelper.cs b/OnnxStack.Core/Video/VideoHelper.cs new file mode 100644 index 00000000..ccd07cd4 --- /dev/null +++ b/OnnxStack.Core/Video/VideoHelper.cs @@ -0,0 +1,310 @@ +using FFMpegCore; +using OnnxStack.Core.Config; +using OnnxStack.Core.Image; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.Core.Video +{ + + public static class VideoHelper + { + private static OnnxStackConfig _configuration = new OnnxStackConfig(); + + /// + /// Sets the configuration. + /// + /// The configuration. + public static void SetConfiguration(OnnxStackConfig configuration) + { + _configuration = configuration; + } + + /// + /// Writes the video frames to file. + /// + /// The onnx video. + /// The filename. + /// The cancellation token. 
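A usage sketch for the new `OnnxVideo` wrapper (file paths are placeholders; FFmpeg must be reachable through the `OnnxStackConfig` used by `VideoHelper`):

```csharp
using System.Threading.Tasks;
using OnnxStack.Core.Video;

public static class OnnxVideoSketch
{
    public static async Task RunAsync()
    {
        // Load a video, optionally capping the decode frame rate at 15fps
        using var video = await OnnxVideo.FromFileAsync("Input.mp4", frameRate: 15);

        // Frames are OnnxImage instances, so per-frame image operations apply directly
        video.Resize(512, 512);

        // Re-encode through FFmpeg and write the result to disk
        await video.SaveAsync("Resized.mp4");
    }
}
```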
+ public static async Task WriteVideoFramesAsync(OnnxVideo onnxVideo, string filename, CancellationToken cancellationToken = default) + { + await WriteVideoFramesAsync(onnxVideo.Frames, filename, onnxVideo.FrameRate, onnxVideo.AspectRatio, cancellationToken); + } + + + /// + /// Writes the video frames to file. + /// + /// The onnx images. + /// The filename. + /// The frame rate. + /// The cancellation token. + public static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate = 15, CancellationToken cancellationToken = default) + { + var firstImage = onnxImages.First(); + var aspectRatio = (double)firstImage.Width / firstImage.Height; + await WriteVideoFramesAsync(onnxImages, filename, frameRate, aspectRatio, cancellationToken); + } + + + /// + /// Writes the video frames to file. + /// + /// The onnx images. + /// The filename. + /// The frame rate. + /// The aspect ratio. + /// The cancellation token. + private static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate, double aspectRatio, CancellationToken cancellationToken = default) + { + using (var videoWriter = CreateWriter(filename, frameRate, aspectRatio)) + { + // Start FFMPEG + videoWriter.Start(); + foreach (var image in onnxImages) + { + // Write each frame to the input stream of FFMPEG + await Task.Yield(); + await videoWriter.StandardInput.BaseStream.WriteAsync(image.GetImageBytes(), cancellationToken); + } + + // Done close stream and wait for app to process + videoWriter.StandardInput.BaseStream.Close(); + await videoWriter.WaitForExitAsync(cancellationToken); + } + } + + + /// + /// Reads the video information. + /// + /// The video bytes. + /// + public static async Task ReadVideoInfoAsync(byte[] videoBytes) + { + using (var memoryStream = new MemoryStream(videoBytes)) + { + var result = await FFProbe.AnalyseAsync(memoryStream).ConfigureAwait(false); + return new VideoInfo(result.PrimaryVideoStream.Width, result.PrimaryVideoStream.Height, result.Duration, (int)result.PrimaryVideoStream.FrameRate); + } + } + + + /// + /// Reads the video frames. + /// + /// The video bytes. + /// The frame rate. + /// The cancellation token. + /// + public static async Task> ReadVideoFramesAsync(byte[] videoBytes, float frameRate = 15, CancellationToken cancellationToken = default) + { + return await CreateFramesInternalAsync(videoBytes, frameRate, cancellationToken) + .Select(x => new OnnxImage(x)) + .ToListAsync(cancellationToken); + } + + + #region Private Members + + + /// + /// Creates a collection of PNG frames from a video source + /// + /// The video data. + /// The FPS. + /// The cancellation token. 
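For callers working with raw bytes rather than files, the static `VideoHelper` can be used directly. A sketch, assuming a default `OnnxStackConfig` whose FFmpeg and temp paths are valid on the machine:

```csharp
using System.Threading.Tasks;
using OnnxStack.Core.Config;
using OnnxStack.Core.Video;

public static class VideoHelperSketch
{
    public static async Task RoundTripAsync(byte[] videoBytes)
    {
        // VideoHelper shells out to FFmpeg, so point it at a valid configuration first
        VideoHelper.SetConfiguration(new OnnxStackConfig());

        // Probe the source for dimensions, duration and frame rate
        var info = await VideoHelper.ReadVideoInfoAsync(videoBytes);

        // Decode frames to OnnxImage at the source frame rate
        var frames = await VideoHelper.ReadVideoFramesAsync(videoBytes, info.FrameRate);

        // Re-encode the frames as an MP4
        await VideoHelper.WriteVideoFramesAsync(frames, "Output.mp4", info.FrameRate);
    }
}
```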
+ /// + /// Invalid PNG header + private static async IAsyncEnumerable CreateFramesInternalAsync(byte[] videoData, float fps = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string tempVideoPath = GetTempFilename(); + try + { + await File.WriteAllBytesAsync(tempVideoPath, videoData, cancellationToken); + using (var ffmpegProcess = CreateReader(tempVideoPath, fps)) + { + // Start FFMPEG + ffmpegProcess.Start(); + + // FFMPEG output stream + var processOutputStream = ffmpegProcess.StandardOutput.BaseStream; + + // Buffer to hold the current image + var buffer = new byte[20480000]; + + var currentIndex = 0; + while (!cancellationToken.IsCancellationRequested) + { + // Reset the index new PNG + currentIndex = 0; + + // Read the PNG Header + if (await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 8), cancellationToken) <= 0) + break; + + currentIndex += 8;// header length + + if (!IsImageHeader(buffer)) + throw new Exception("Invalid PNG header"); + + // loop through each chunk + while (true) + { + // Read the chunk header + await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 12), cancellationToken); + + var chunkIndex = currentIndex; + currentIndex += 12; // Chunk header length + + // Get the chunk's content size in bytes from the header we just read + var totalSize = buffer[chunkIndex] << 24 | buffer[chunkIndex + 1] << 16 | buffer[chunkIndex + 2] << 8 | buffer[chunkIndex + 3]; + if (totalSize > 0) + { + var totalRead = 0; + while (totalRead < totalSize) + { + int read = await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, totalSize - totalRead), cancellationToken); + currentIndex += read; + totalRead += read; + } + continue; + } + + // If the size is 0 and is the end of the image + if (totalSize == 0 && IsImageEnd(buffer, chunkIndex)) + break; + } + + yield return buffer[..currentIndex]; + } + + if (cancellationToken.IsCancellationRequested) + ffmpegProcess.Kill(); + } + } + finally + { + DeleteTempFile(tempVideoPath); + } + } + + + /// + /// Gets the temporary filename. + /// + /// + private static string GetTempFilename() + { + if (!Directory.Exists(_configuration.TempPath)) + Directory.CreateDirectory(_configuration.TempPath); + + return Path.Combine(_configuration.TempPath, $"{Path.GetFileNameWithoutExtension(Path.GetRandomFileName())}.mp4"); + } + + + /// + /// Deletes the temporary file. + /// + /// The filename. + private static void DeleteTempFile(string filename) + { + try + { + if (File.Exists(filename)) + File.Delete(filename); + } + catch (Exception) + { + // File in use, Log + } + } + + + /// + /// Creates FFMPEG video reader process. + /// + /// The input file. + /// The FPS. + /// + private static Process CreateReader(string inputFile, float fps) + { + var ffmpegProcess = new Process(); + ffmpegProcess.StartInfo.FileName = _configuration.FFmpegPath; + ffmpegProcess.StartInfo.Arguments = $"-hide_banner -loglevel error -i \"{inputFile}\" -c:v png -r {fps} -f image2pipe -"; + ffmpegProcess.StartInfo.RedirectStandardOutput = true; + ffmpegProcess.StartInfo.UseShellExecute = false; + ffmpegProcess.StartInfo.CreateNoWindow = true; + return ffmpegProcess; + } + + + /// + /// Creates FFMPEG video writer process. + /// + /// The output file. + /// The FPS. + /// The aspect ratio. 
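The frame reader above works because FFmpeg's `image2pipe` output is a back-to-back stream of PNGs, and a PNG's end can always be found by walking its chunks to the zero-length `IEND` chunk. An illustrative sketch of that framing rule (not part of the patch; assumes a complete, well-formed PNG at the start of the span):

```csharp
using System;

public static class PngFramingSketch
{
    // Returns the total length in bytes of the PNG that starts at data[0].
    public static int GetPngLength(ReadOnlySpan<byte> data)
    {
        var offset = 8; // 8-byte PNG signature: 89 50 4E 47 0D 0A 1A 0A
        while (true)
        {
            // Chunk layout: [4-byte big-endian length][4-byte type][data][4-byte CRC]
            int length = data[offset] << 24 | data[offset + 1] << 16
                       | data[offset + 2] << 8 | data[offset + 3];

            bool isEnd = data[offset + 4] == 'I' && data[offset + 5] == 'E'
                      && data[offset + 6] == 'N' && data[offset + 7] == 'D';

            offset += 12 + length; // length + type + data + CRC

            if (isEnd)
                return offset; // the next byte (if any) starts the next piped frame
        }
    }
}
```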
+ /// + private static Process CreateWriter(string outputFile, float fps, double aspectRatio) + { + var ffmpegProcess = new Process(); + ffmpegProcess.StartInfo.FileName = _configuration.FFmpegPath; + ffmpegProcess.StartInfo.Arguments = $"-hide_banner -loglevel error -framerate {fps:F4} -i - -c:v libx264 -movflags +faststart -vf format=yuv420p -aspect {aspectRatio} {outputFile}"; + ffmpegProcess.StartInfo.RedirectStandardInput = true; + ffmpegProcess.StartInfo.UseShellExecute = false; + ffmpegProcess.StartInfo.CreateNoWindow = true; + return ffmpegProcess; + } + + + /// + /// Determines whether we are at the start of a PNG image in the specified buffer. + /// + /// The buffer. + /// The offset. + /// + /// true if the start of a PNG image sequence is detectedfalse. + /// + private static bool IsImageHeader(byte[] buffer) + { + // PNG Header http://www.libpng.org/pub/png/spec/1.2/PNG-Structure.html#PNG-file-signature + if (buffer[0] != 0x89 + || buffer[1] != 0x50 + || buffer[2] != 0x4E + || buffer[3] != 0x47 + || buffer[4] != 0x0D + || buffer[5] != 0x0A + || buffer[6] != 0x1A + || buffer[7] != 0x0A) + return false; + + return true; + } + + + /// + /// Determines whether we are at the end of a PNG image in the specified buffer. + /// + /// The buffer. + /// The offset. + /// + /// true if the end of a PNG image sequence is detectedfalse. + /// + private static bool IsImageEnd(byte[] buffer, int offset) + { + return buffer[offset + 4] == 0x49 // I + && buffer[offset + 5] == 0x45 // E + && buffer[offset + 6] == 0x4E // N + && buffer[offset + 7] == 0x44; // D + } + } + + #endregion +} diff --git a/OnnxStack.Core/Video/VideoInfo.cs b/OnnxStack.Core/Video/VideoInfo.cs index 17f3fb39..1a67870f 100644 --- a/OnnxStack.Core/Video/VideoInfo.cs +++ b/OnnxStack.Core/Video/VideoInfo.cs @@ -2,5 +2,14 @@ namespace OnnxStack.Core.Video { - public record VideoInfo(int Width, int Height, TimeSpan Duration, float FPS); + public sealed record VideoInfo(TimeSpan Duration, float FrameRate) + { + public VideoInfo(int height, int width, TimeSpan duration, float frameRate) : this(duration, frameRate) + { + Height = height; + Width = width; + } + public int Height { get; set; } + public int Width { get; set; } + } } diff --git a/OnnxStack.Core/Video/VideoInput.cs b/OnnxStack.Core/Video/VideoInput.cs deleted file mode 100644 index 4dc299e4..00000000 --- a/OnnxStack.Core/Video/VideoInput.cs +++ /dev/null @@ -1,118 +0,0 @@ -using Microsoft.Extensions.Primitives; -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core.Config; -using OnnxStack.Core.Services; -using System.IO; -using System.Text.Json.Serialization; -using System.Threading; -using System.Threading.Tasks; - -namespace OnnxStack.Core.Video -{ - public class VideoInput - { - /// - /// Initializes a new instance of the class. - /// - public VideoInput() { } - - /// - /// Initializes a new instance of the class. - /// - /// The video bytes. - public VideoInput(byte[] videoBytes) => VideoBytes = videoBytes; - - /// - /// Initializes a new instance of the class. - /// - /// The video stream. - public VideoInput(Stream videoStream) => VideoStream = videoStream; - - /// - /// Initializes a new instance of the class. - /// - /// The video tensor. - public VideoInput(DenseTensor videoTensor) => VideoTensor = videoTensor; - - /// - /// Initializes a new instance of the class. - /// - /// The video frames. - public VideoInput(VideoFrames videoFrames) => VideoFrames = videoFrames; - - - /// - /// Gets the video bytes. 
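With the `VideoInfo` record reshaped, `Duration` and `FrameRate` stay positional while `Height`/`Width` become settable so `OnnxVideo.Resize` can update them in place. A small sketch of the new shape (values are placeholders):

```csharp
using System;
using OnnxStack.Core.Video;

public static class VideoInfoSketch
{
    public static VideoInfo Example()
    {
        // Convenience constructor taking dimensions plus the positional Duration/FrameRate
        var info = new VideoInfo(height: 384, width: 640, duration: TimeSpan.FromSeconds(4), frameRate: 15);

        // FrameRate is still init-only, so use `with` to derive a lower-FPS variant
        info = info with { FrameRate = 10 };

        // Dimensions are now mutable, matching how OnnxVideo.Resize updates them
        info.Width = 512;
        info.Height = 512;
        return info;
    }
}
```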
- /// - [JsonIgnore] - public byte[] VideoBytes { get; set; } - - - /// - /// Gets the video stream. - /// - [JsonIgnore] - public Stream VideoStream { get; set; } - - - /// - /// Gets the video tensor. - /// - [JsonIgnore] - public DenseTensor VideoTensor { get; set; } - - - /// - /// Gets or sets the video frames. - /// - [JsonIgnore] - public VideoFrames VideoFrames { get; set; } - - - /// - /// Gets a value indicating whether this instance has video. - /// - /// - /// true if this instance has video; otherwise, false. - /// - [JsonIgnore] - public bool HasVideo => VideoBytes != null - || VideoStream != null - || VideoTensor != null - || VideoFrames != null; - - - - /// - /// Create a VideoInput from file - /// - /// The video file. - /// The target FPS. - /// The configuration. - /// The cancellation token. - /// - public static async Task FromFileAsync(string videoFile, float? targetFPS = default, OnnxStackConfig config = default, CancellationToken cancellationToken = default) - { - var videoBytes = await File.ReadAllBytesAsync(videoFile, cancellationToken); - var videoService = new VideoService(config ?? new OnnxStackConfig()); - var videoFrames = await videoService.CreateFramesAsync(videoBytes, targetFPS, cancellationToken); - return new VideoInput(videoFrames); - } - - - /// - /// Saves the video file - /// - /// The video tensor. - /// The video file. - /// The target FPS. - /// The configuration. - /// The cancellation token. - public static async Task SaveFileAsync(DenseTensor videoTensor, string videoFile, float targetFPS, OnnxStackConfig config = default, CancellationToken cancellationToken = default) - { - var videoService = new VideoService(config ?? new OnnxStackConfig()); - var videoOutput = await videoService.CreateVideoAsync(videoTensor, targetFPS, cancellationToken); - await File.WriteAllBytesAsync(videoFile, videoOutput.Data, cancellationToken); - } - } -} diff --git a/OnnxStack.Core/Video/VideoOutput.cs b/OnnxStack.Core/Video/VideoOutput.cs deleted file mode 100644 index db4195b3..00000000 --- a/OnnxStack.Core/Video/VideoOutput.cs +++ /dev/null @@ -1,4 +0,0 @@ -namespace OnnxStack.Core.Video -{ - public record VideoOutput(byte[] Data, VideoInfo Info); -} diff --git a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs index 1d80ef8e..7c2f99e4 100644 --- a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs +++ b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs @@ -5,6 +5,7 @@ using OnnxStack.Core.Model; using OnnxStack.Core.Video; using OnnxStack.FeatureExtractor.Common; +using System.Collections.Generic; using System.IO; using System.Linq; using System.Threading; @@ -88,16 +89,16 @@ public async Task RunAsync(OnnxImage inputImage, CancellationToken ca /// /// The input video. 
/// - public async Task RunAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default) + public async Task RunAsync(OnnxVideo video, CancellationToken cancellationToken = default) { var timestamp = _logger?.LogBegin("Extracting video features..."); var metadata = await _featureExtractorModel.GetMetadataAsync(); cancellationToken.ThrowIfCancellationRequested(); - foreach (var videoFrame in videoFrames.Frames) + var frames = new List(); + foreach (var videoFrame in video.Frames) { - var image = new OnnxImage(videoFrame.Frame); - var controlImage = await image.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne); + var controlImage = await videoFrame.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) { inferenceParameters.AddInputTensor(controlImage); @@ -113,12 +114,12 @@ public async Task RunAsync(VideoFrames videoFrames, CancellationTok resultTensor.NormalizeMinMax(); var maskImage = resultTensor.ToImageMask(); - videoFrame.ExtraFrame = maskImage; + frames.Add(maskImage); } } } _logger?.LogEnd("Extracting video features complete.", timestamp); - return videoFrames; + return new OnnxVideo(video.Info, frames); } diff --git a/OnnxStack.FeatureExtractor/README.md b/OnnxStack.FeatureExtractor/README.md index 7b4ad0c3..fac1c9ff 100644 --- a/OnnxStack.FeatureExtractor/README.md +++ b/OnnxStack.FeatureExtractor/README.md @@ -14,11 +14,11 @@ https://huggingface.co/julienkay/sentis-MiDaS ## OpenPose (TODO) https://huggingface.co/axodoxian/controlnet_onnx/resolve/main/annotators/openpose.onnx -# Basic Example +# Image Example ```csharp // Load Input Image -var inputImage = await InputImage.FromFileAsync("Input.png"); +var inputImage = await OnnxImage.FromFileAsync("Input.png"); // Load Pipeline var pipeline = FeatureExtractorPipeline.CreatePipeline("canny.onnx"); @@ -27,7 +27,26 @@ var pipeline = FeatureExtractorPipeline.CreatePipeline("canny.onnx"); var imageFeature = await pipeline.RunAsync(inputImage); // Save Image -await imageFeature.Image.SaveAsPngAsync("Result.png"); +await imageFeature.Image.SaveAsync("Result.png"); + +//Unload +await pipeline.UnloadAsync(); + ``` + + # Video Example +```csharp + +// Load Input Video +var inputVideo = await OnnxVideo.FromFileAsync("Input.mp4"); + +// Load Pipeline +var pipeline = FeatureExtractorPipeline.CreatePipeline("canny.onnx"); + +// Run Pipeline +var videoFeature = await pipeline.RunAsync(inputVideo); + +// Save Video +await videoFeature.SaveAsync("Result.mp4"); //Unload await pipeline.UnloadAsync(); diff --git a/OnnxStack.StableDiffusion/Common/BatchResult.cs b/OnnxStack.StableDiffusion/Common/BatchResult.cs index b682ac2b..192d2480 100644 --- a/OnnxStack.StableDiffusion/Common/BatchResult.cs +++ b/OnnxStack.StableDiffusion/Common/BatchResult.cs @@ -1,7 +1,7 @@ -using OnnxStack.Core.Image; +using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.StableDiffusion.Config; namespace OnnxStack.StableDiffusion.Common { - public record BatchResult(SchedulerOptions SchedulerOptions, OnnxImage ImageResult); + public record BatchResult(SchedulerOptions SchedulerOptions, DenseTensor Result); } diff --git a/OnnxStack.StableDiffusion/Config/PromptOptions.cs b/OnnxStack.StableDiffusion/Config/PromptOptions.cs index 49f368dc..2544b67d 100644 --- a/OnnxStack.StableDiffusion/Config/PromptOptions.cs +++ 
b/OnnxStack.StableDiffusion/Config/PromptOptions.cs @@ -22,10 +22,8 @@ public class PromptOptions public OnnxImage InputContolImage { get; set; } - public VideoInput InputVideo { get; set; } - - public float VideoInputFPS { get; set; } - public float VideoOutputFPS { get; set; } + public OnnxVideo InputVideo { get; set; } + public OnnxVideo InputContolVideo { get; set; } public bool HasInputVideo => InputVideo?.HasVideo ?? false; public bool HasInputImage => InputImage?.HasImage ?? false; diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs index fcc5e976..42af89ac 100644 --- a/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs @@ -1,4 +1,6 @@ using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core.Image; +using OnnxStack.Core.Video; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; @@ -75,5 +77,55 @@ public interface IPipeline /// The cancellation token. /// IAsyncEnumerable RunBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + + + /// + /// Runs the pipeline returning the result as an OnnxImage. + /// + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. + /// + Task GenerateImageAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + + + /// + /// Runs the batch pipeline returning the result as an OnnxImage. + /// + /// The batch options. + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. + /// + IAsyncEnumerable GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + + + /// + /// Runs the pipeline returning the result as an OnnxVideo. + /// + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. + /// + Task GenerateVideoAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + + + /// + /// Runs the batch pipeline returning the result as an OnnxVideo. + /// + /// The batch options. + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. 
+ /// + IAsyncEnumerable GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs index 26c0d6b2..97aa4ad4 100644 --- a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs +++ b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs @@ -2,6 +2,7 @@ using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.Core; using OnnxStack.Core.Image; +using OnnxStack.Core.Video; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Diffusers; @@ -95,6 +96,8 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger) public abstract Task> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + + /// /// Runs the pipeline batch. /// @@ -108,6 +111,56 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger) public abstract IAsyncEnumerable RunBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + /// + /// Runs the pipeline returning the result as an OnnxImage. + /// + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. + /// + public abstract Task GenerateImageAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + + + /// + /// Runs the batch pipeline returning the result as an OnnxImage. + /// + /// The batch options. + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. + /// + public abstract IAsyncEnumerable GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + + + /// + /// Runs the pipeline returning the result as an OnnxVideo. + /// + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. + /// + public abstract Task GenerateVideoAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + + + /// + /// Runs the batch pipeline returning the result as an OnnxVideo. + /// + /// The batch options. + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. 
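The new Generate* members sit alongside RunAsync rather than replacing it: RunAsync still yields the raw tensor, while the Generate* overloads return wrapped results. A small sketch of the distinction (prompt options assumed to be configured as in the README examples):

```csharp
// Existing API: raw tensor, caller wraps it
var tensor = await pipeline.RunAsync(promptOptions);
var imageFromTensor = new OnnxImage(tensor);

// New API added here: wrapped results
var image = await pipeline.GenerateImageAsync(promptOptions);
var video = await pipeline.GenerateVideoAsync(promptOptions); // assumes promptOptions.InputVideo is set
```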
+ /// + public abstract IAsyncEnumerable GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default); + + /// /// Creates the diffuser. /// @@ -146,42 +199,38 @@ protected async Task> DiffuseImageAsync(IDiffuser diffuser, P /// The progress callback. /// The cancellation token. /// - protected async Task> DiffuseVideoAsync(IDiffuser diffuser, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected async IAsyncEnumerable> DiffuseVideoAsync(IDiffuser diffuser, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var diffuseTime = _logger?.LogBegin("Video Diffuser starting..."); var frameIndex = 0; - DenseTensor videoTensor = null; - var videoFrames = promptOptions.InputVideo.VideoFrames.Frames; + var videoFrames = promptOptions.InputVideo.Frames; var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex); foreach (var videoFrame in videoFrames) { - frameIndex++; - // byte[] videoFrame = videoFrames[i].Frame; if (promptOptions.DiffuserType == DiffuserType.ControlNet || promptOptions.DiffuserType == DiffuserType.ControlNetImage) { // ControlNetImage uses frame as input image if (promptOptions.DiffuserType == DiffuserType.ControlNetImage) - promptOptions.InputImage = new OnnxImage(videoFrame.Frame); + promptOptions.InputImage = videoFrame; - promptOptions.InputContolImage = videoFrame.ExtraFrame; + promptOptions.InputContolImage = promptOptions.InputContolVideo?.GetFrame(frameIndex); } else { - promptOptions.InputImage = new OnnxImage(videoFrame.Frame); + promptOptions.InputImage = videoFrame; } var frameResultTensor = await diffuser.DiffuseAsync(promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken); // Frame Progress - ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor); + ReportBatchProgress(progressCallback, ++frameIndex, videoFrames.Count, frameResultTensor); // Concatenate frame - videoTensor = videoTensor.Concatenate(frameResultTensor); + yield return frameResultTensor; } _logger?.LogEnd($"Video Diffuser complete", diffuseTime); - return videoTensor; } diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs index 5d3b809f..227defaf 100644 --- a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs @@ -4,6 +4,7 @@ using OnnxStack.Core.Config; using OnnxStack.Core.Image; using OnnxStack.Core.Model; +using OnnxStack.Core.Video; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Diffusers; @@ -154,7 +155,7 @@ public override void ValidateInputs(PromptOptions promptOptions, SchedulerOption /// - /// Runs the pipeline. + /// Runs the pipeline returning the tensor result. /// /// The prompt options. /// The scheduler options. 
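The reworked DiffuseVideoAsync streams one frame tensor at a time and, for ControlNet runs, pairs each input frame with the matching frame of the new InputContolVideo property (name as spelled in the source). A sketch of the prompt setup this implies, assuming inputVideo and controlVideo are OnnxVideo instances with matching frame counts:

```csharp
var promptOptions = new PromptOptions
{
    Prompt = "photo of a cat",
    DiffuserType = DiffuserType.ControlNet,
    InputVideo = inputVideo,          // frames to diffuse
    InputContolVideo = controlVideo   // per-frame control images, consumed via GetFrame(index)
};
```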
@@ -164,15 +165,12 @@ public override void ValidateInputs(PromptOptions promptOptions, SchedulerOption /// public override async Task> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default) { - // Create random seed if none was set - schedulerOptions ??= _defaultSchedulerOptions; - schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next(); - var diffuseTime = _logger?.LogBegin("Diffuser starting..."); - _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}"); + var options = GetSchedulerOptionsOrDefault(schedulerOptions); + _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}"); // Check guidance - var performGuidance = ShouldPerformGuidance(schedulerOptions); + var performGuidance = ShouldPerformGuidance(options); // Process prompts var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance); @@ -181,9 +179,18 @@ public override async Task> RunAsync(PromptOptions promptOpti var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet); // Diffuse - var tensorResult = promptOptions.HasInputVideo - ? await DiffuseVideoAsync(diffuser, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken) - : await DiffuseImageAsync(diffuser, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); + var tensorResult = default(DenseTensor); + if (promptOptions.HasInputVideo) + { + await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken)) + { + tensorResult = tensorResult.Concatenate(frameTensor); + } + } + else + { + tensorResult = await DiffuseImageAsync(diffuser, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); + } _logger?.LogEnd($"Diffuser complete", diffuseTime); return tensorResult; @@ -202,22 +209,103 @@ public override async Task> RunAsync(PromptOptions promptOpti /// public override async IAsyncEnumerable RunBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - // Create random seed if none was set - schedulerOptions ??= _defaultSchedulerOptions; - schedulerOptions.Seed = schedulerOptions.Seed > 0 ? 
schedulerOptions.Seed : Random.Shared.Next(); + var diffuseBatchTime = _logger?.LogBegin("Batch Diffuser starting..."); + var options = GetSchedulerOptionsOrDefault(schedulerOptions); + _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}"); + _logger?.Log($"BatchType: {batchOptions.BatchType}, ValueFrom: {batchOptions.ValueFrom}, ValueTo: {batchOptions.ValueTo}, Increment: {batchOptions.Increment}"); + + // Check guidance + var performGuidance = ShouldPerformGuidance(options); + + // Process prompts + var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance); + + // Generate batch options + var batchSchedulerOptions = BatchGenerator.GenerateBatch(this, batchOptions, options); + + // Create Diffuser + var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet); + + // Diffuse + var batchIndex = 1; + var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex); + foreach (var batchSchedulerOption in batchSchedulerOptions) + { + var tensorResult = default(DenseTensor); + if (promptOptions.HasInputVideo) + { + await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken)) + { + tensorResult = tensorResult.Concatenate(frameTensor); + } + } + else + { + tensorResult = await DiffuseImageAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken); + } + yield return new BatchResult(batchSchedulerOption, tensorResult); + batchIndex++; + } + + _logger?.LogEnd($"Batch Diffuser complete", diffuseBatchTime); + } + + /// + /// Runs the pipeline returning the result as an OnnxImage. + /// + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. + /// + public override async Task GenerateImageAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default) + { + var diffuseTime = _logger?.LogBegin("Diffuser starting..."); + var options = GetSchedulerOptionsOrDefault(schedulerOptions); + _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}"); + + // Check guidance + var performGuidance = ShouldPerformGuidance(options); + + // Process prompts + var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance); + + // Create Diffuser + var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet); + + var imageResult = await DiffuseImageAsync(diffuser, promptOptions, options, promptEmbeddings, performGuidance, progressCallback, cancellationToken); + + return new OnnxImage(imageResult); + } + + + /// + /// Runs the batch pipeline returning the result as an OnnxImage. + /// + /// The batch options. + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. 
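BatchResult now carries the raw tensor (Result) instead of an OnnxImage, so batch consumers wrap each result themselves. A sketch of the updated loop, assuming batchOptions and promptOptions are configured as in the README batch example:

```csharp
await foreach (var batch in pipeline.RunBatchAsync(batchOptions, promptOptions))
{
    var image = new OnnxImage(batch.Result);
    await image.SaveAsync($"Output_Batch_{batch.SchedulerOptions.Seed}.png");
}
```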
+ /// + public override async IAsyncEnumerable GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { var diffuseBatchTime = _logger?.LogBegin("Batch Diffuser starting..."); - _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}"); + var options = GetSchedulerOptionsOrDefault(schedulerOptions); + _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}"); _logger?.Log($"BatchType: {batchOptions.BatchType}, ValueFrom: {batchOptions.ValueFrom}, ValueTo: {batchOptions.ValueTo}, Increment: {batchOptions.Increment}"); // Check guidance - var performGuidance = ShouldPerformGuidance(schedulerOptions); + var performGuidance = ShouldPerformGuidance(options); // Process prompts var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance); // Generate batch options - var batchSchedulerOptions = BatchGenerator.GenerateBatch(this, batchOptions, schedulerOptions); + var batchSchedulerOptions = BatchGenerator.GenerateBatch(this, batchOptions, options); // Create Diffuser var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet); @@ -227,11 +315,88 @@ public override async IAsyncEnumerable RunBatchAsync(BatchOptions b var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex); foreach (var batchSchedulerOption in batchSchedulerOptions) { - var tensorResult = promptOptions.HasInputVideo - ? await DiffuseVideoAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken) - : await DiffuseImageAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken); + var tensorResult = await DiffuseImageAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken); + yield return new OnnxImage(tensorResult); + batchIndex++; + } + + _logger?.LogEnd($"Batch Diffuser complete", diffuseBatchTime); + } + + + /// + /// Runs the pipeline returning the result as an OnnxVideo. + /// + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. 
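For callers that only want images, the new GenerateImageBatchAsync yields OnnxImage directly and skips the tensor wrapping shown above. A minimal sketch, reusing the same options:

```csharp
var index = 0;
await foreach (var image in pipeline.GenerateImageBatchAsync(batchOptions, promptOptions))
{
    await image.SaveAsync($"Output_Batch_{index++}.png");
}
```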
+ /// + public override async Task GenerateVideoAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default) + { + var diffuseTime = _logger?.LogBegin("Diffuser starting..."); + var options = GetSchedulerOptionsOrDefault(schedulerOptions); + _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}"); + + // Check guidance + var performGuidance = ShouldPerformGuidance(options); + + // Process prompts + var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance); + + // Create Diffuser + var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet); + + var frames = new List(); + await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, options, promptEmbeddings, performGuidance, progressCallback, cancellationToken)) + { + frames.Add(new OnnxImage(frameTensor)); + } + return new OnnxVideo(promptOptions.InputVideo.Info, frames); + } + + + /// + /// Runs the batch pipeline returning the result as an OnnxVideo. + /// + /// The batch options. + /// The prompt options. + /// The scheduler options. + /// The control net. + /// The progress callback. + /// The cancellation token. + /// + public override async IAsyncEnumerable GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var diffuseBatchTime = _logger?.LogBegin("Batch Diffuser starting..."); + var options = GetSchedulerOptionsOrDefault(schedulerOptions); + _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}"); + _logger?.Log($"BatchType: {batchOptions.BatchType}, ValueFrom: {batchOptions.ValueFrom}, ValueTo: {batchOptions.ValueTo}, Increment: {batchOptions.Increment}"); + + // Check guidance + var performGuidance = ShouldPerformGuidance(options); - yield return new BatchResult(batchSchedulerOption, new OnnxImage(tensorResult)); + // Process prompts + var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance); + + // Generate batch options + var batchSchedulerOptions = BatchGenerator.GenerateBatch(this, batchOptions, options); + + // Create Diffuser + var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet); + + // Diffuse + var batchIndex = 1; + var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex); + foreach (var batchSchedulerOption in batchSchedulerOptions) + { + var frames = new List(); + await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, options, promptEmbeddings, performGuidance, progressCallback, cancellationToken)) + { + frames.Add(new OnnxImage(frameTensor)); + } + yield return new OnnxVideo(promptOptions.InputVideo.Info, frames); batchIndex++; } @@ -239,6 +404,24 @@ public override async IAsyncEnumerable RunBatchAsync(BatchOptions b } + /// + /// Gets the scheduler options or the default scheduler options + /// + /// The scheduler options. 
+ /// + private SchedulerOptions GetSchedulerOptionsOrDefault(SchedulerOptions schedulerOptions) + { + // Create random seed if none was set + if (schedulerOptions == null) + return _defaultSchedulerOptions with { Seed = Random.Shared.Next() }; + + if (schedulerOptions.Seed <= 0) + return schedulerOptions with { Seed = Random.Shared.Next() }; + + return schedulerOptions; + } + + /// /// Overrides the vae encoder with a custom implementation, Caller is responsible for model lifetime /// diff --git a/OnnxStack.UI/App.xaml.cs b/OnnxStack.UI/App.xaml.cs index 1de842c9..ea8a7a2e 100644 --- a/OnnxStack.UI/App.xaml.cs +++ b/OnnxStack.UI/App.xaml.cs @@ -2,9 +2,6 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using OnnxStack.Core; -using OnnxStack.Core.Services; -using OnnxStack.ImageUpscaler; -using OnnxStack.StableDiffusion.Config; using OnnxStack.UI.Dialogs; using OnnxStack.UI.Models; using OnnxStack.UI.Services; @@ -50,7 +47,6 @@ public App() builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); - builder.Services.AddSingleton(); // Build App diff --git a/OnnxStack.UI/Services/IUpscaleService.cs b/OnnxStack.UI/Services/IUpscaleService.cs index ed1712c6..84469c90 100644 --- a/OnnxStack.UI/Services/IUpscaleService.cs +++ b/OnnxStack.UI/Services/IUpscaleService.cs @@ -51,6 +51,6 @@ public interface IUpscaleService /// The model options. /// The video input. /// - Task> GenerateAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken = default); + Task> GenerateAsync(UpscaleModelSet modelOptions, OnnxVideo videoInput, CancellationToken cancellationToken = default); } } diff --git a/OnnxStack.UI/UserControls/SchedulerControl.xaml.cs b/OnnxStack.UI/UserControls/SchedulerControl.xaml.cs index 267b8c95..9aa095e3 100644 --- a/OnnxStack.UI/UserControls/SchedulerControl.xaml.cs +++ b/OnnxStack.UI/UserControls/SchedulerControl.xaml.cs @@ -4,6 +4,7 @@ using OnnxStack.UI.Commands; using OnnxStack.UI.Models; using System; +using System.Collections.Generic; using System.Collections.ObjectModel; using System.ComponentModel; using System.Linq; @@ -24,7 +25,7 @@ public partial class SchedulerControl : UserControl, INotifyPropertyChanged /// Initializes a new instance of the class. 
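One behavioral consequence of the GetSchedulerOptionsOrDefault helper above: scheduler options are no longer mutated in place when a seed is generated; the random seed is applied to a copy via a with-expression. A small sketch of what that means for callers:

```csharp
var options = pipeline.DefaultSchedulerOptions with { Seed = 0 };
var image = await pipeline.GenerateImageAsync(promptOptions, options);

// options.Seed is still 0 here; the generated seed exists only on the pipeline's copy,
// so callers that want to log or reuse the seed should set one explicitly up front.
```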
public SchedulerControl() { - ValidSizes = new ObservableCollection(Constants.ValidSizes); + ValidSizes = new ObservableCollection(new [] { 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024 }); NewSeedCommand = new RelayCommand(NewSeed); RandomSeedCommand = new RelayCommand(RandomSeed); ResetParametersCommand = new RelayCommand(ResetParameters); From 299c42193def95280e692f9e1602f0cce19ea2c9 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Fri, 16 Feb 2024 11:41:12 +1300 Subject: [PATCH 4/5] Update example UI --- OnnxStack.Core/Video/VideoHelper.cs | 34 +++- .../Services/IStableDiffusionService.cs | 19 +- OnnxStack.UI/Services/IUpscaleService.cs | 9 - .../Services/StableDiffusionService.cs | 176 ++++-------------- OnnxStack.UI/Services/UpscaleService.cs | 48 +---- .../UserControls/VideoInputControl.xaml.cs | 12 +- OnnxStack.UI/Views/ImageInpaintView.xaml.cs | 2 +- OnnxStack.UI/Views/ImageToImageView.xaml.cs | 2 +- OnnxStack.UI/Views/TextToImageView.xaml.cs | 2 +- OnnxStack.UI/Views/VideoToVideoView.xaml.cs | 44 ++--- 10 files changed, 106 insertions(+), 242 deletions(-) diff --git a/OnnxStack.Core/Video/VideoHelper.cs b/OnnxStack.Core/Video/VideoHelper.cs index ccd07cd4..0b754733 100644 --- a/OnnxStack.Core/Video/VideoHelper.cs +++ b/OnnxStack.Core/Video/VideoHelper.cs @@ -63,6 +63,9 @@ public static async Task WriteVideoFramesAsync(IEnumerable onnxImages /// The cancellation token. private static async Task WriteVideoFramesAsync(IEnumerable onnxImages, string filename, float frameRate, double aspectRatio, CancellationToken cancellationToken = default) { + if (File.Exists(filename)) + File.Delete(filename); + using (var videoWriter = CreateWriter(filename, frameRate, aspectRatio)) { // Start FFMPEG @@ -70,7 +73,6 @@ private static async Task WriteVideoFramesAsync(IEnumerable onnxImage foreach (var image in onnxImages) { // Write each frame to the input stream of FFMPEG - await Task.Yield(); await videoWriter.StandardInput.BaseStream.WriteAsync(image.GetImageBytes(), cancellationToken); } @@ -96,11 +98,23 @@ public static async Task ReadVideoInfoAsync(byte[] videoBytes) } + /// + /// Reads the video information. + /// + /// The filename. + /// + public static async Task ReadVideoInfoAsync(string filename) + { + var result = await FFProbe.AnalyseAsync(filename).ConfigureAwait(false); + return new VideoInfo(result.PrimaryVideoStream.Width, result.PrimaryVideoStream.Height, result.Duration, (int)result.PrimaryVideoStream.FrameRate); + } + + /// /// Reads the video frames. /// /// The video bytes. - /// The frame rate. + /// The target frame rate. /// The cancellation token. /// public static async Task> ReadVideoFramesAsync(byte[] videoBytes, float frameRate = 15, CancellationToken cancellationToken = default) @@ -111,6 +125,22 @@ public static async Task> ReadVideoFramesAsync(byte[] videoBytes } + /// + /// Reads the video frames. + /// + /// The video bytes. + /// The target frame rate. + /// The cancellation token. 
+ /// + public static async Task> ReadVideoFramesAsync(string filename, float frameRate = 15, CancellationToken cancellationToken = default) + { + var videoBytes = await File.ReadAllBytesAsync(filename, cancellationToken); + return await CreateFramesInternalAsync(videoBytes, frameRate, cancellationToken) + .Select(x => new OnnxImage(x)) + .ToListAsync(cancellationToken); + } + + #region Private Members diff --git a/OnnxStack.UI/Services/IStableDiffusionService.cs b/OnnxStack.UI/Services/IStableDiffusionService.cs index 73c934c4..ae39e6a1 100644 --- a/OnnxStack.UI/Services/IStableDiffusionService.cs +++ b/OnnxStack.UI/Services/IStableDiffusionService.cs @@ -1,9 +1,8 @@ -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core.Image; +using OnnxStack.Core.Image; +using OnnxStack.Core.Video; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -68,18 +67,18 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - Task GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task GenerateImageAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + /// - /// Generates a batch of StableDiffusion image using the prompt and options provided. + /// Generates the StableDiffusion video using the prompt and options provided. /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The batch options. + /// The model. + /// The prompt. + /// The options. /// The progress callback. /// The cancellation token. /// - IAsyncEnumerable GenerateBatchAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + Task GenerateVideoAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/OnnxStack.UI/Services/IUpscaleService.cs b/OnnxStack.UI/Services/IUpscaleService.cs index 84469c90..36e2ae61 100644 --- a/OnnxStack.UI/Services/IUpscaleService.cs +++ b/OnnxStack.UI/Services/IUpscaleService.cs @@ -43,14 +43,5 @@ public interface IUpscaleService /// The input image. /// Task GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default); - - - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. 
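The new file-based overloads make it possible to probe and decode a video without the removed VideoService. A minimal sketch combining them with the OnnxVideo constructor used elsewhere in this patch (file names and the 15 fps target are example values):

```csharp
var info = await VideoHelper.ReadVideoInfoAsync("Input.mp4");
var frames = await VideoHelper.ReadVideoFramesAsync("Input.mp4", frameRate: 15);
var video = new OnnxVideo(info with { FrameRate = 15 }, frames);
await video.SaveAsync("Resampled.mp4");
```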
- /// - Task> GenerateAsync(UpscaleModelSet modelOptions, OnnxVideo videoInput, CancellationToken cancellationToken = default); } } diff --git a/OnnxStack.UI/Services/StableDiffusionService.cs b/OnnxStack.UI/Services/StableDiffusionService.cs index 095a4ea6..1338132c 100644 --- a/OnnxStack.UI/Services/StableDiffusionService.cs +++ b/OnnxStack.UI/Services/StableDiffusionService.cs @@ -1,22 +1,18 @@ using Microsoft.Extensions.Logging; -using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.Core; using OnnxStack.Core.Config; using OnnxStack.Core.Image; -using OnnxStack.Core.Services; +using OnnxStack.Core.Video; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Models; using OnnxStack.StableDiffusion.Pipelines; using OnnxStack.UI.Models; -using SixLabors.ImageSharp; using SixLabors.ImageSharp.PixelFormats; using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.IO; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -28,7 +24,6 @@ namespace OnnxStack.UI.Services /// public sealed class StableDiffusionService : IStableDiffusionService { - private readonly IVideoService _videoService; private readonly ILogger _logger; private readonly OnnxStackUIConfig _configuration; private readonly Dictionary _pipelines; @@ -38,11 +33,10 @@ public sealed class StableDiffusionService : IStableDiffusionService /// Initializes a new instance of the class. /// /// The scheduler service. - public StableDiffusionService(OnnxStackUIConfig configuration, IVideoService videoService, ILogger logger) + public StableDiffusionService(OnnxStackUIConfig configuration, ILogger logger) { _logger = logger; _configuration = configuration; - _videoService = videoService; _pipelines = new Dictionary(); _controlNetSessions = new ConcurrentDictionary(); } @@ -64,8 +58,6 @@ public async Task LoadModelAsync(StableDiffusionModelSet model) } - - /// /// Unloads the model. /// @@ -95,6 +87,11 @@ public bool IsModelLoaded(StableDiffusionModelSet modelOptions) } + /// + /// Loads the model. + /// + /// + /// public async Task LoadControlNetModelAsync(ControlNetModelSet model) { if (_controlNetSessions.ContainsKey(model)) @@ -106,6 +103,12 @@ public async Task LoadControlNetModelAsync(ControlNetModelSet model) return _controlNetSessions.TryAdd(model, controlNet); } + + /// + /// Unloads the model. + /// + /// + /// public Task UnloadControlNetModelAsync(ControlNetModelSet model) { if (_controlNetSessions.Remove(model, out var controlNet)) @@ -115,6 +118,14 @@ public Task UnloadControlNetModelAsync(ControlNetModelSet model) return Task.FromResult(true); } + + /// + /// Determines whether the specified model is loaded + /// + /// The model options. + /// + /// true if the specified model is loaded; otherwise, false. + /// public bool IsControlNetModelLoaded(ControlNetModelSet modelOptions) { return _controlNetSessions.ContainsKey(modelOptions); @@ -129,164 +140,55 @@ public bool IsControlNetModelLoaded(ControlNetModelSet modelOptions) /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. 
/// The diffusion result as - public async Task GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) - { - return await DiffuseAsync(model, prompt, options, progressCallback, cancellationToken) - .ContinueWith(t => new OnnxImage(t.Result), cancellationToken) - .ConfigureAwait(false); - } - - - - - - /// - /// Generates a batch of StableDiffusion image using the prompt and options provided. - /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The batch options. - /// The progress callback. - /// The cancellation token. - /// - public IAsyncEnumerable GenerateBatchAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) - { - return DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken); - } - - - - - - - - - - /// - /// Runs the diffusion process - /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The progress. - /// The cancellation token. - /// - /// - /// Pipeline not found or is unsupported - /// or - /// Diffuser not found or is unsupported - /// or - /// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline. - /// - private async Task> DiffuseAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progress = null, CancellationToken cancellationToken = default) + public async Task GenerateImageAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { - if (!_pipelines.TryGetValue(modelOptions.BaseModel, out var pipeline)) + if (!_pipelines.TryGetValue(model.BaseModel, out var pipeline)) throw new Exception("Pipeline not found or is unsupported"); var controlNet = default(ControlNetModel); - if (modelOptions.ControlNetModel is not null && !_controlNetSessions.TryGetValue(modelOptions.ControlNetModel, out controlNet)) + if (model.ControlNetModel is not null && !_controlNetSessions.TryGetValue(model.ControlNetModel, out controlNet)) throw new Exception("ControlNet not loaded"); - pipeline.ValidateInputs(promptOptions, schedulerOptions); + pipeline.ValidateInputs(prompt, options); - await GenerateInputVideoFrames(promptOptions, progress); - return await pipeline.RunAsync(promptOptions, schedulerOptions, controlNet, progress, cancellationToken); + return await pipeline.GenerateImageAsync(prompt, options, controlNet, progressCallback, cancellationToken); } /// - /// Runs the batch diffusion process. + /// Generates the StableDiffusion video using the prompt and options provided. /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The batch options. - /// The progress. + /// The model. + /// The prompt. + /// The options. + /// The progress callback. /// The cancellation token. /// /// /// Pipeline not found or is unsupported /// or - /// Diffuser not found or is unsupported - /// or - /// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline. 
+ /// ControlNet not loaded /// - private async IAsyncEnumerable DiffuseBatchAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task GenerateVideoAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { - if (!_pipelines.TryGetValue(modelOptions.BaseModel, out var pipeline)) + if (!_pipelines.TryGetValue(model.BaseModel, out var pipeline)) throw new Exception("Pipeline not found or is unsupported"); var controlNet = default(ControlNetModel); - if (modelOptions.ControlNetModel is not null && !_controlNetSessions.TryGetValue(modelOptions.ControlNetModel, out controlNet)) + if (model.ControlNetModel is not null && !_controlNetSessions.TryGetValue(model.ControlNetModel, out controlNet)) throw new Exception("ControlNet not loaded"); - pipeline.ValidateInputs(promptOptions, schedulerOptions); + pipeline.ValidateInputs(prompt, options); - await GenerateInputVideoFrames(promptOptions, progressCallback); - await foreach (var result in pipeline.RunBatchAsync(batchOptions, promptOptions, schedulerOptions, controlNet, progressCallback, cancellationToken)) - { - yield return result; - } + return await pipeline.GenerateVideoAsync(prompt, options, controlNet, progressCallback, cancellationToken); } /// - /// Generates the video result as bytes. + /// Creates the pipeline. /// - /// The options. - /// The video tensor. - /// The progress. - /// The cancellation token. - /// - private async Task GenerateVideoResultAsBytesAsync(DenseTensor videoTensor, float videoFPS, Action progress = null, CancellationToken cancellationToken = default) - { - progress?.Invoke(new DiffusionProgress("Generating Video Result...")); - var videoResult = await _videoService.CreateVideoAsync(videoTensor, videoFPS, cancellationToken); - return videoResult.Data; - } - - - /// - /// Generates the video result as stream. - /// - /// The options. - /// The video tensor. - /// The progress. - /// The cancellation token. + /// The model. /// - private async Task GenerateVideoResultAsStreamAsync(DenseTensor videoTensor, float videoFPS, Action progress = null, CancellationToken cancellationToken = default) - { - return new MemoryStream(await GenerateVideoResultAsBytesAsync(videoTensor, videoFPS, progress, cancellationToken)); - } - - - /// - /// Generates the input video frames. - /// - /// The prompt options. - /// The progress. 
- private async Task GenerateInputVideoFrames(PromptOptions promptOptions, Action progress) - { - if (!promptOptions.HasInputVideo || promptOptions.InputVideo.VideoFrames is not null) - return; - - if (promptOptions.VideoInputFPS == 0 || promptOptions.VideoOutputFPS == 0) - { - var videoInfo = await _videoService.GetVideoInfoAsync(promptOptions.InputVideo); - if (promptOptions.VideoInputFPS == 0) - promptOptions.VideoInputFPS = videoInfo.FPS; - - if (promptOptions.VideoOutputFPS == 0) - promptOptions.VideoOutputFPS = videoInfo.FPS; - } - - var videoFrame = await _videoService.CreateFramesAsync(promptOptions.InputVideo, promptOptions.VideoInputFPS); - progress?.Invoke(new DiffusionProgress($"Generating video frames @ {promptOptions.VideoInputFPS}fps")); - promptOptions.InputVideo.VideoFrames = videoFrame; - } - - private IPipeline CreatePipeline(StableDiffusionModelSet model) { return model.PipelineType switch diff --git a/OnnxStack.UI/Services/UpscaleService.cs b/OnnxStack.UI/Services/UpscaleService.cs index 419de2d2..36823225 100644 --- a/OnnxStack.UI/Services/UpscaleService.cs +++ b/OnnxStack.UI/Services/UpscaleService.cs @@ -1,10 +1,7 @@ using Microsoft.Extensions.Logging; using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core; using OnnxStack.Core.Config; using OnnxStack.Core.Image; -using OnnxStack.Core.Services; -using OnnxStack.Core.Video; using OnnxStack.FeatureExtractor.Pipelines; using OnnxStack.ImageUpscaler.Common; using System; @@ -16,7 +13,6 @@ namespace OnnxStack.UI.Services { public class UpscaleService : IUpscaleService { - private readonly IVideoService _videoService; private readonly ILogger _logger; private readonly Dictionary _pipelines; @@ -26,9 +22,8 @@ public class UpscaleService : IUpscaleService /// The configuration. /// The model service. /// The image service. - public UpscaleService(IVideoService videoService) + public UpscaleService() { - _videoService = videoService; _pipelines = new Dictionary(); } @@ -90,29 +85,6 @@ public async Task GenerateAsync(UpscaleModelSet modelOptions, OnnxIma } - - - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. - /// - public async Task> GenerateAsync(UpscaleModelSet modelOptions, VideoInput videoInput, CancellationToken cancellationToken = default) - { - var videoInfo = await _videoService.GetVideoInfoAsync(videoInput); - var tensorFrames = await GenerateInternalAsync(modelOptions, videoInput, videoInfo, cancellationToken); - - DenseTensor videoResult = default; - foreach (var tensorFrame in tensorFrames) - { - cancellationToken.ThrowIfCancellationRequested(); - videoResult = videoResult.Concatenate(tensorFrame); - } - return videoResult; - } - - /// /// Generates an upscaled image of the source provided. /// @@ -127,24 +99,6 @@ private async Task> GenerateInternalAsync(UpscaleModelSet mod } - /// - /// Generates the upscaled video. - /// - /// The model options. - /// The video input. 
- /// - public async Task>> GenerateInternalAsync(UpscaleModelSet modelSet, VideoInput videoInput, VideoInfo videoInfo, CancellationToken cancellationToken) - { - if (!_pipelines.TryGetValue(modelSet, out var pipeline)) - throw new Exception("Pipeline not found or is unsupported"); - - return new List>(); - } - - - - - private ImageUpscalePipeline CreatePipeline(UpscaleModelSet modelSet) { return ImageUpscalePipeline.CreatePipeline(modelSet, _logger); diff --git a/OnnxStack.UI/UserControls/VideoInputControl.xaml.cs b/OnnxStack.UI/UserControls/VideoInputControl.xaml.cs index 74a1330b..26492d8c 100644 --- a/OnnxStack.UI/UserControls/VideoInputControl.xaml.cs +++ b/OnnxStack.UI/UserControls/VideoInputControl.xaml.cs @@ -1,5 +1,5 @@ using Microsoft.Win32; -using OnnxStack.Core.Services; +using OnnxStack.Core.Video; using OnnxStack.UI.Commands; using OnnxStack.UI.Models; using System; @@ -15,7 +15,6 @@ namespace OnnxStack.UI.UserControls { public partial class VideoInputControl : UserControl, INotifyPropertyChanged { - private readonly IVideoService _videoService; private bool _isPlaying = false; /// @@ -23,9 +22,6 @@ public partial class VideoInputControl : UserControl, INotifyPropertyChanged /// public VideoInputControl() { - if (!DesignerProperties.GetIsInDesignMode(this)) - _videoService = App.GetService(); - LoadVideoCommand = new AsyncRelayCommand(LoadVideo); ClearVideoCommand = new AsyncRelayCommand(ClearVideo); InitializeComponent(); @@ -119,7 +115,7 @@ private async Task LoadVideo() if (openFileDialog.ShowDialog() == true) { var videoBytes = await File.ReadAllBytesAsync(openFileDialog.FileName); - var videoInfo = await _videoService.GetVideoInfoAsync(videoBytes); + var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoBytes); VideoResult = new VideoInputModel { FileName = openFileDialog.FileName, @@ -127,8 +123,8 @@ private async Task LoadVideo() VideoBytes = videoBytes }; HasVideoResult = true; - PromptOptions.VideoInputFPS = videoInfo.FPS; - PromptOptions.VideoOutputFPS = videoInfo.FPS; + PromptOptions.VideoInputFPS = videoInfo.FrameRate; + PromptOptions.VideoOutputFPS = videoInfo.FrameRate; } } diff --git a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs index 1821680f..66e01a33 100644 --- a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs +++ b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs @@ -220,7 +220,7 @@ private async Task Generate() try { var timestamp = Stopwatch.GetTimestamp(); - var result = await _stableDiffusionService.GenerateAsync(new ModelOptions(_selectedModel.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); + var result = await _stableDiffusionService.GenerateImageAsync(new ModelOptions(_selectedModel.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); var resultImage = await GenerateResultAsync(result, promptOptions, schedulerOptions, timestamp); if (resultImage != null) { diff --git a/OnnxStack.UI/Views/ImageToImageView.xaml.cs b/OnnxStack.UI/Views/ImageToImageView.xaml.cs index 9d394088..d67c0ad2 100644 --- a/OnnxStack.UI/Views/ImageToImageView.xaml.cs +++ b/OnnxStack.UI/Views/ImageToImageView.xaml.cs @@ -208,7 +208,7 @@ private async Task Generate() try { var timestamp = Stopwatch.GetTimestamp(); - var result = await _stableDiffusionService.GenerateAsync(new ModelOptions(_selectedModel.ModelSet, _selectedControlNetModel?.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); + var result = 
await _stableDiffusionService.GenerateImageAsync(new ModelOptions(_selectedModel.ModelSet, _selectedControlNetModel?.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); var resultImage = await GenerateResultAsync(result, promptOptions, schedulerOptions, timestamp); if (resultImage != null) { diff --git a/OnnxStack.UI/Views/TextToImageView.xaml.cs b/OnnxStack.UI/Views/TextToImageView.xaml.cs index 05df66a1..a6cba29e 100644 --- a/OnnxStack.UI/Views/TextToImageView.xaml.cs +++ b/OnnxStack.UI/Views/TextToImageView.xaml.cs @@ -181,7 +181,7 @@ private async Task Generate() try { var timestamp = Stopwatch.GetTimestamp(); - var result = await _stableDiffusionService.GenerateAsync(new ModelOptions(_selectedModel.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); + var result = await _stableDiffusionService.GenerateImageAsync(new ModelOptions(_selectedModel.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); var resultImage = await GenerateResultAsync(result, promptOptions, schedulerOptions, timestamp); if (resultImage != null) { diff --git a/OnnxStack.UI/Views/VideoToVideoView.xaml.cs b/OnnxStack.UI/Views/VideoToVideoView.xaml.cs index 02af02c0..f3b28768 100644 --- a/OnnxStack.UI/Views/VideoToVideoView.xaml.cs +++ b/OnnxStack.UI/Views/VideoToVideoView.xaml.cs @@ -1,15 +1,12 @@ using Microsoft.Extensions.Logging; using OnnxStack.Core.Image; -using OnnxStack.Core.Services; using OnnxStack.Core.Video; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; -using OnnxStack.StableDiffusion.Helpers; using OnnxStack.UI.Commands; using OnnxStack.UI.Models; -using SixLabors.ImageSharp.PixelFormats; -using SixLabors.ImageSharp; +using OnnxStack.UI.Services; using System; using System.Collections.Generic; using System.Collections.ObjectModel; @@ -23,7 +20,6 @@ using System.Windows.Controls; using System.Windows.Media.Imaging; using System.Windows.Threading; -using OnnxStack.UI.Services; namespace OnnxStack.UI.Views { @@ -34,7 +30,7 @@ public partial class VideoToVideoView : UserControl, INavigatable, INotifyProper { private readonly ILogger _logger; private readonly IStableDiffusionService _stableDiffusionService; - private readonly IVideoService _videoService; + private bool _hasResult; private int _progressMax; @@ -51,7 +47,7 @@ public partial class VideoToVideoView : UserControl, INavigatable, INotifyProper private PromptOptionsModel _promptOptionsModel; private SchedulerOptionsModel _schedulerOptions; private CancellationTokenSource _cancelationTokenSource; - private VideoFrames _videoFrames; + private OnnxVideo _videoFrames; private BitmapImage _previewSource; private BitmapImage _previewResult; @@ -63,7 +59,6 @@ public VideoToVideoView() if (!DesignerProperties.GetIsInDesignMode(this)) { _logger = App.GetService>(); - _videoService = App.GetService(); _stableDiffusionService = App.GetService(); } @@ -217,15 +212,16 @@ private async Task Generate() try { var schedulerOptions = SchedulerOptions.ToSchedulerOptions(); - if (_videoFrames is null || _videoFrames.Info.FPS != PromptOptions.VideoInputFPS) + if (_videoFrames is null || _videoFrames.Info.FrameRate != PromptOptions.VideoInputFPS) { ProgressText = $"Generating video frames @ {PromptOptions.VideoInputFPS}fps"; - _videoFrames = await _videoService.CreateFramesAsync(_inputVideo.VideoBytes, PromptOptions.VideoInputFPS, _cancelationTokenSource.Token); + var 
frames = await VideoHelper.ReadVideoFramesAsync(_inputVideo.VideoBytes, PromptOptions.VideoInputFPS, _cancelationTokenSource.Token); + _videoFrames = new OnnxVideo(_inputVideo.VideoInfo with { FrameRate = PromptOptions.VideoInputFPS }, frames); } var promptOptions = GetPromptOptions(PromptOptions, _videoFrames); var timestamp = Stopwatch.GetTimestamp(); - var result = await _stableDiffusionService.GenerateAsync(new ModelOptions(_selectedModel.ModelSet, _selectedControlNetModel?.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); + var result = await _stableDiffusionService.GenerateVideoAsync(new ModelOptions(_selectedModel.ModelSet, _selectedControlNetModel?.ModelSet), promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); var resultVideo = await GenerateResultAsync(result, promptOptions, schedulerOptions, timestamp); if (resultVideo != null) { @@ -324,7 +320,7 @@ private void Reset() } - private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, VideoFrames videoFrames) + private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, OnnxVideo videoFrames) { var diffuserType = DiffuserType.ImageToImage; if (_selectedModel.IsControlNet) @@ -339,9 +335,7 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, Vi Prompt = promptOptionsModel.Prompt, NegativePrompt = promptOptionsModel.NegativePrompt, DiffuserType = diffuserType, - InputVideo = new VideoInput(videoFrames), - VideoInputFPS = promptOptionsModel.VideoInputFPS, - VideoOutputFPS = promptOptionsModel.VideoOutputFPS, + InputVideo = videoFrames }; } @@ -354,17 +348,15 @@ private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, Vi /// The scheduler options. /// The timestamp. 
/// - private async Task GenerateResultAsync(OnnxImage onnxImage, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) + private async Task GenerateResultAsync(OnnxVideo onnxVideo, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) { - var videoBytes = onnxImage.GetImageBytes(); - - var tempVideoFile = Path.Combine(".temp", $"VideoToVideo.mp4"); - await File.WriteAllBytesAsync(tempVideoFile, videoBytes); - var videoInfo = await _videoService.GetVideoInfoAsync(videoBytes); + var tempVideoFile = Path.Combine(".temp", $"VideoToVideo.mp4"); + await onnxVideo.SaveAsync(tempVideoFile); + var videoBytes = await File.ReadAllBytesAsync(tempVideoFile); var videoResult = new VideoInputModel { FileName = tempVideoFile, - VideoInfo = videoInfo, + VideoInfo = onnxVideo.Info, VideoBytes = videoBytes }; return videoResult; @@ -389,7 +381,7 @@ private Action ProgressCallback() if (progress.BatchTensor is not null) { - PreviewResult = Utils.CreateBitmap(new OnnxImage( progress.BatchTensor).GetImageBytes()); + PreviewResult = Utils.CreateBitmap(new OnnxImage(progress.BatchTensor).GetImageBytes()); PreviewSource = UpdatePreviewFrame(progress.BatchValue - 1); ProgressText = $"Video Frame {progress.BatchValue} of {_videoFrames.Frames.Count} complete"; } @@ -411,10 +403,10 @@ public BitmapImage UpdatePreviewFrame(int index) { var frame = _videoFrames.Frames[index]; using (var memoryStream = new MemoryStream()) - using (var frameImage = SixLabors.ImageSharp.Image.Load(frame.Frame)) { - //frameImage.Resize(_schedulerOptions.Height, _schedulerOptions.Width); - frameImage.SaveAsPng(memoryStream); + var frameImage = frame.Clone(); + frameImage.Resize(_schedulerOptions.Height, _schedulerOptions.Width); + frameImage.Save(memoryStream); var image = new BitmapImage(); image.BeginInit(); image.CacheOption = BitmapCacheOption.OnLoad; From 16ddf2a6a5d49be6e7cc97b6c4f3379b8c499b56 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Fri, 16 Feb 2024 11:47:52 +1300 Subject: [PATCH 5/5] Update README --- OnnxStack.StableDiffusion/README.md | 30 +++++++++++++---------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/OnnxStack.StableDiffusion/README.md b/OnnxStack.StableDiffusion/README.md index e27989b4..f979af85 100644 --- a/OnnxStack.StableDiffusion/README.md +++ b/OnnxStack.StableDiffusion/README.md @@ -53,11 +53,10 @@ var pipeline = StableDiffusionPipeline.CreatePipeline("models\\stable-diffusion- var promptOptions = new PromptOptions { Prompt = "Photo of a cute dog." 
}; // Run Pipleine -var result = await pipeline.RunAsync(promptOptions); +var result = await pipeline.GenerateImageAsync(promptOptions); // Save image result -var image = result.ToImage(); -await image.SaveAsPngAsync("D:\\Results\\Image.png"); +await result.SaveAsync("D:\\Results\\Image.png"); // Unload Pipleine await pipeline.UnloadAsync(); @@ -86,8 +85,8 @@ var batchOptions = new BatchOptions await foreach (var result in pipeline.RunBatchAsync(batchOptions, promptOptions)) { // Save Image result - var image = result.ImageResult.ToImage(); - await image.SaveAsPngAsync($"Output_Batch_{result.SchedulerOptions.Seed}.png"); + var image = new OnnxImage(result.ImageResult); + await image.SaveAsync($"Output_Batch_{result.SchedulerOptions.Seed}.png"); } // Unload Pipleine @@ -107,7 +106,7 @@ Run Stable Diffusion process with an initial image as input var pipeline = StableDiffusionPipeline.CreatePipeline("models\\stable-diffusion-v1-5"); // Load Input Image -var inputImage = await InputImage.FromFileAsync("Input.png"); +var inputImage = await OnnxImage.FromFileAsync("Input.png"); // Set Prompt Options var promptOptions = new PromptOptions @@ -125,11 +124,10 @@ var schedulerOptions = pipeline.DefaultSchedulerOptions with }; // Run Pipleine -var result = await pipeline.RunAsync(promptOptions, schedulerOptions); +var result = await pipeline.GenerateImageAsync(promptOptions, schedulerOptions); // Save image result -var image = result.ToImage(); -await image.SaveAsPngAsync("Output_ImageToImage.png"); +await result.SaveAsync("Output_ImageToImage.png"); // Unload Pipleine await pipeline.UnloadAsync(); @@ -153,7 +151,7 @@ var pipeline = StableDiffusionPipeline.CreatePipeline("models\\stable_diffusion_ var controlNet = ControlNetModel.Create("models\\controlnet_onnx\\controlnet\\depth.onnx"); // Load Control Image -var controlImage = await InputImage.FromFileAsync("Input_Depth.png"); +var controlImage = await OnnxImage.FromFileAsync("Input_Depth.png"); // Set Prompt Options var promptOptions = new PromptOptions @@ -164,11 +162,10 @@ var promptOptions = new PromptOptions }; // Run Pipleine -var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet); +var result = await pipeline.GenerateImageAsync(promptOptions, controlNet: controlNet); // Save image result -var image = result.ToImage(); -await image.SaveAsPngAsync("Output_ControlNet.png"); +await result.SaveAsync("Output_ControlNet.png"); // Unload Pipleine await pipeline.UnloadAsync(); @@ -194,7 +191,7 @@ var pipeline = StableDiffusionPipeline.CreatePipeline("models\\stable-diffusion- // Load Video var targetFPS = 15; - var videoInput = await VideoInput.FromFileAsync("Input.gif", targetFPS); + var videoInput = await OnnxVideo.FromFileAsync("Input.gif", targetFPS); // Add text and video to prompt var promptOptions = new PromptOptions @@ -205,11 +202,10 @@ var pipeline = StableDiffusionPipeline.CreatePipeline("models\\stable-diffusion- }; // Run pipeline - var result = await pipeline.RunAsync(promptOptions, progressCallback: OutputHelpers.FrameProgressCallback); + var result = await pipeline.GenerateVideoAsync(promptOptions); // Save Video File - var outputFilename = Path.Combine(_outputDirectory, "Output_VideoToVideo.mp4"); - await VideoInput.SaveFileAsync(result, outputFilename, targetFPS); + await result.SaveAsync("Output_VideoToVideo.mp4"); // Unload Pipleine await pipeline.UnloadAsync();
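The README above stops at plain video-to-video; combining the pieces introduced in this patch, a ControlNet video run would look roughly like the sketch below (model and annotator paths are placeholders, and the per-frame control images come from the feature extractor's video overload):

```csharp
// Load input video and extract per-frame control images (e.g. depth)
var inputVideo = await OnnxVideo.FromFileAsync("Input.mp4", 15);
var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("annotators\\depth.onnx");
var controlVideo = await annotationPipeline.RunAsync(inputVideo);

// Create pipeline and ControlNet
var pipeline = StableDiffusionPipeline.CreatePipeline("models\\stable_diffusion_onnx");
var controlNet = ControlNetModel.Create("controlnet\\depth.onnx");

// Prompt with input and control video
var promptOptions = new PromptOptions
{
    Prompt = "Photo of a cute dog.",
    DiffuserType = DiffuserType.ControlNet,
    InputVideo = inputVideo,
    InputContolVideo = controlVideo
};

// Run Pipeline and save video result
var result = await pipeline.GenerateVideoAsync(promptOptions, controlNet: controlNet);
await result.SaveAsync("Output_ControlNetVideo.mp4");

// Unload
await annotationPipeline.UnloadAsync();
await controlNet.UnloadAsync();
await pipeline.UnloadAsync();
```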