This repository has been archived by the owner on Nov 27, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #96 from saddam213/ControlNet
LCM ControlNet, SDXL ControlNet, InstaFlow ControlNet
- Loading branch information
Showing
9 changed files
with
1,093 additions
and
29 deletions.
There are no files selected for viewing
208 changes: 208 additions & 0 deletions
208
OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
using Microsoft.Extensions.Logging; | ||
using Microsoft.ML.OnnxRuntime.Tensors; | ||
using OnnxStack.Core; | ||
using OnnxStack.Core.Config; | ||
using OnnxStack.Core.Model; | ||
using OnnxStack.Core.Services; | ||
using OnnxStack.StableDiffusion.Common; | ||
using OnnxStack.StableDiffusion.Config; | ||
using OnnxStack.StableDiffusion.Enums; | ||
using OnnxStack.StableDiffusion.Helpers; | ||
using OnnxStack.StableDiffusion.Models; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Diagnostics; | ||
using System.Linq; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
|
||
namespace OnnxStack.StableDiffusion.Diffusers.InstaFlow | ||
{ | ||
public class ControlNetDiffuser : InstaFlowDiffuser | ||
{ | ||
private readonly IControlNetImageService _controlNetImageService; | ||
|
||
/// <summary> | ||
/// Initializes a new instance of the <see cref="ControlNetDiffuser"/> class. | ||
/// </summary> | ||
/// <param name="configuration">The configuration.</param> | ||
/// <param name="onnxModelService">The onnx model service.</param> | ||
public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger<ControlNetDiffuser> logger) | ||
: base(onnxModelService, promptService, logger) | ||
{ | ||
_controlNetImageService = controlNetImageService; | ||
} | ||
|
||
/// <summary> | ||
/// Gets the type of the diffuser. | ||
/// </summary> | ||
public override DiffuserType DiffuserType => DiffuserType.ControlNet; | ||
|
||
|
||
/// <summary> | ||
/// Called on each Scheduler step. | ||
/// </summary> | ||
/// <param name="modelOptions">The model options.</param> | ||
/// <param name="promptOptions">The prompt options.</param> | ||
/// <param name="schedulerOptions">The scheduler options.</param> | ||
/// <param name="promptEmbeddings">The prompt embeddings.</param> | ||
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param> | ||
/// <param name="progressCallback">The progress callback.</param> | ||
/// <param name="cancellationToken">The cancellation token.</param> | ||
/// <returns></returns> | ||
/// <exception cref="NotImplementedException"></exception> | ||
protected override async Task<DenseTensor<float>> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default) | ||
{ | ||
// Get Scheduler | ||
using (var scheduler = GetScheduler(schedulerOptions)) | ||
{ | ||
// Get timesteps | ||
var timesteps = GetTimesteps(schedulerOptions, scheduler); | ||
|
||
// Create latent sample | ||
var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps); | ||
|
||
// Get Model metadata | ||
var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet); | ||
|
||
// Get Model metadata | ||
var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet); | ||
|
||
// Control Image | ||
var controlImage = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions); | ||
|
||
// Get the distilled Timestep | ||
var distilledTimestep = 1.0f / timesteps.Count; | ||
|
||
// Loop though the timesteps | ||
var step = 0; | ||
foreach (var timestep in timesteps) | ||
{ | ||
step++; | ||
var stepTime = Stopwatch.GetTimestamp(); | ||
cancellationToken.ThrowIfCancellationRequested(); | ||
|
||
// Create input tensor. | ||
var inputLatent = performGuidance ? latents.Repeat(2) : latents; | ||
var inputTensor = scheduler.ScaleInput(inputLatent, timestep); | ||
var timestepTensor = CreateTimestepTensor(timestep); | ||
var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage; | ||
var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale); | ||
|
||
var outputChannels = performGuidance ? 2 : 1; | ||
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels); | ||
using (var inferenceParameters = new OnnxInferenceParameters(metadata)) | ||
{ | ||
inferenceParameters.AddInputTensor(inputTensor); | ||
inferenceParameters.AddInputTensor(timestepTensor); | ||
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds); | ||
|
||
// ControlNet | ||
using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata)) | ||
{ | ||
controlNetParameters.AddInputTensor(inputTensor); | ||
controlNetParameters.AddInputTensor(timestepTensor); | ||
controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds); | ||
controlNetParameters.AddInputTensor(controlImage); | ||
if (controlNetMetadata.Inputs.Count == 5) | ||
controlNetParameters.AddInputTensor(conditioningScale); | ||
|
||
// Optimization: Pre-allocate device buffers for inputs | ||
foreach (var item in controlNetMetadata.Outputs) | ||
controlNetParameters.AddOutputBuffer(); | ||
|
||
// ControlNet inference | ||
var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters); | ||
|
||
// Add ControlNet outputs to Unet input | ||
foreach (var item in controlNetResults) | ||
inferenceParameters.AddInput(item); | ||
|
||
// Add output buffer | ||
inferenceParameters.AddOutputBuffer(outputDimension); | ||
|
||
// Unet inference | ||
var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters); | ||
using (var result = results.First()) | ||
{ | ||
var noisePred = result.ToDenseTensor(); | ||
|
||
// Perform guidance | ||
if (performGuidance) | ||
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale); | ||
|
||
// Scheduler Step | ||
latents = scheduler.Step(noisePred, timestep, latents).Result; | ||
|
||
latents = noisePred | ||
.MultiplyTensorByFloat(distilledTimestep) | ||
.AddTensors(latents); | ||
} | ||
} | ||
} | ||
|
||
ReportProgress(progressCallback, step, timesteps.Count, latents); | ||
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); | ||
} | ||
|
||
// Decode Latents | ||
return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents); | ||
} | ||
} | ||
|
||
|
||
/// <summary> | ||
/// Gets the timesteps. | ||
/// </summary> | ||
/// <param name="options">The options.</param> | ||
/// <param name="scheduler">The scheduler.</param> | ||
/// <returns></returns> | ||
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler) | ||
{ | ||
return scheduler.Timesteps; | ||
} | ||
|
||
|
||
/// <summary> | ||
/// Prepares the input latents. | ||
/// </summary> | ||
/// <param name="model">The model.</param> | ||
/// <param name="prompt">The prompt.</param> | ||
/// <param name="options">The options.</param> | ||
/// <param name="scheduler">The scheduler.</param> | ||
/// <param name="timesteps">The timesteps.</param> | ||
/// <returns></returns> | ||
protected override Task<DenseTensor<float>> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps) | ||
{ | ||
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma)); | ||
} | ||
|
||
|
||
/// <summary> | ||
/// Creates the Conditioning Scale tensor. | ||
/// </summary> | ||
/// <param name="conditioningScale">The conditioningScale.</param> | ||
/// <returns></returns> | ||
protected static DenseTensor<double> CreateConditioningScaleTensor(float conditioningScale) | ||
{ | ||
return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 }); | ||
} | ||
|
||
|
||
/// <summary> | ||
/// Prepares the control image. | ||
/// </summary> | ||
/// <param name="promptOptions">The prompt options.</param> | ||
/// <param name="schedulerOptions">The scheduler options.</param> | ||
/// <returns></returns> | ||
protected async Task<DenseTensor<float>> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) | ||
{ | ||
var controlImage = promptOptions.InputContolImage; | ||
if (schedulerOptions.IsControlImageProcessingEnabled) | ||
{ | ||
controlImage = await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width); | ||
} | ||
return controlImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false); | ||
} | ||
} | ||
} |
Oops, something went wrong.