From 81cf62724ffb8bb24d8dd248e3ef50cc1b1da3d4 Mon Sep 17 00:00:00 2001 From: trueai-org Date: Tue, 9 Jul 2024 16:58:14 +0800 Subject: [PATCH] add modal zoom action --- .../Controllers/SubmitController.cs | 66 ++++++++++- src/Midjourney.Infrastructure/Constants.cs | 5 + .../Dto/SubmitModalDTO.cs | 29 +++++ src/Midjourney.Infrastructure/Enums.cs | 5 +- .../LoadBalancer/DiscordInstanceImpl.cs | 14 ++- .../LoadBalancer/IDiscordInstance.cs | 2 +- .../Midjourney.Infrastructure.csproj | 2 + .../Resources/ApiParams/zoom.json | 24 ++++ .../Services/DiscordServiceImpl.cs | 24 ++++ .../Services/IDiscordService.cs | 10 ++ .../Services/ITaskService.cs | 8 ++ .../Services/TaskServiceImpl.cs | 105 +++++++++++++++--- 12 files changed, 272 insertions(+), 22 deletions(-) create mode 100644 src/Midjourney.Infrastructure/Dto/SubmitModalDTO.cs create mode 100644 src/Midjourney.Infrastructure/Resources/ApiParams/zoom.json diff --git a/src/Midjourney.API/Controllers/SubmitController.cs b/src/Midjourney.API/Controllers/SubmitController.cs index 1a55dc03..0624621a 100644 --- a/src/Midjourney.API/Controllers/SubmitController.cs +++ b/src/Midjourney.API/Controllers/SubmitController.cs @@ -322,6 +322,13 @@ public ActionResult Action([FromBody] SubmitActionDTO actionDTO) { task.Action = TaskAction.ACTION; } + // 自定义变焦 + // "MJ::CustomZoom::439f8670-52e8-4f57-afaa-fa08f6d6c751" + else if (actionDTO.CustomId.StartsWith("MJ::CustomZoom::")) + { + task.Action = TaskAction.ACTION; + task.Description = "Waiting for window confirm"; + } else { task.Action = TaskAction.ACTION; @@ -336,6 +343,59 @@ public ActionResult Action([FromBody] SubmitActionDTO actionDTO) return Ok(_taskService.SubmitAction(task, actionDTO)); } + /// + /// 提交 Modal + /// + /// + /// + [HttpPost("modal")] + public ActionResult Modal([FromBody] SubmitModalDTO actionDTO) + { + if (string.IsNullOrWhiteSpace(actionDTO.TaskId) || string.IsNullOrWhiteSpace(actionDTO.Prompt)) + { + return BadRequest(SubmitResultVO.Fail(ReturnCode.VALIDATION_ERROR, "参数错误")); + } + + var targetTask = _taskStoreService.Get(actionDTO.TaskId); + if (targetTask == null) + { + return NotFound(SubmitResultVO.Fail(ReturnCode.NOT_FOUND, "关联任务不存在或已失效")); + } + + var prompt = actionDTO.Prompt; + var task = targetTask; + + var promptEn = TranslatePrompt(prompt); + try + { + BannedPromptUtils.CheckBanned(promptEn); + } + catch (BannedPromptException e) + { + return BadRequest(SubmitResultVO.Fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词") + .SetProperty("promptEn", promptEn) + .SetProperty("bannedWord", e.Message)); + } + + //List base64Array = imagineDTO.Base64Array ?? new List(); + + //List dataUrls = new List(); + //try + //{ + // dataUrls = ConvertUtils.ConvertBase64Array(base64Array); + //} + //catch (Exception e) + //{ + // _logger.LogError(e, "base64格式转换异常"); + // return BadRequest(SubmitResultVO.Fail(ReturnCode.VALIDATION_ERROR, "base64格式错误")); + //} + + task.PromptEn = promptEn; + + return Ok(_taskService.SubmitModal(task, actionDTO)); + } + + /// /// 创建新的任务对象 /// @@ -353,10 +413,10 @@ private TaskInfo NewTask(BaseSubmitDTO baseDTO) var notifyHook = string.IsNullOrWhiteSpace(baseDTO.NotifyHook) ? _properties.NotifyHook : baseDTO.NotifyHook; task.SetProperty(Constants.TASK_PROPERTY_NOTIFY_HOOK, notifyHook); - var none = SnowFlake.NextId(); - task.Nonce = none; + var nonce = SnowFlake.NextId(); + task.Nonce = nonce; - task.SetProperty(Constants.TASK_PROPERTY_NONCE, none); + task.SetProperty(Constants.TASK_PROPERTY_NONCE, nonce); return task; } diff --git a/src/Midjourney.Infrastructure/Constants.cs b/src/Midjourney.Infrastructure/Constants.cs index 942cd423..ae680b83 100644 --- a/src/Midjourney.Infrastructure/Constants.cs +++ b/src/Midjourney.Infrastructure/Constants.cs @@ -43,6 +43,11 @@ public static class Constants /// public const string TASK_PROPERTY_PROGRESS_MESSAGE_ID = "progressMessageId"; + /// + /// 执行动作 custom_id + /// + public const string TASK_PROPERTY_CUSTOM_ID = "custom_id"; + /// /// 标志. /// diff --git a/src/Midjourney.Infrastructure/Dto/SubmitModalDTO.cs b/src/Midjourney.Infrastructure/Dto/SubmitModalDTO.cs new file mode 100644 index 00000000..1e510784 --- /dev/null +++ b/src/Midjourney.Infrastructure/Dto/SubmitModalDTO.cs @@ -0,0 +1,29 @@ +using Swashbuckle.AspNetCore.Annotations; + +namespace Midjourney.Infrastructure.Dto +{ + /// + /// Imagine提交参数。 + /// + [SwaggerSchema("Imagine提交参数")] + public class SubmitModalDTO : BaseSubmitDTO + { + /// + /// 提示词。 + /// + [SwaggerSchema("提示词", Description = "Cat")] + public string Prompt { get; set; } + + /// + /// 任务ID。 + /// + [SwaggerSchema("任务ID", Description = "\"1320098173412546\"")] + public string TaskId { get; set; } + + /// + /// 局部重绘的蒙版base64 + /// + [SwaggerSchema("图片base64", Description = "data:image/png;base64,xxx")] + public string MaskBase64 { get; set; } + } +} \ No newline at end of file diff --git a/src/Midjourney.Infrastructure/Enums.cs b/src/Midjourney.Infrastructure/Enums.cs index 10dcab1c..d785a043 100644 --- a/src/Midjourney.Infrastructure/Enums.cs +++ b/src/Midjourney.Infrastructure/Enums.cs @@ -117,7 +117,10 @@ public enum TaskAction /// OUTPAINT, - //ZOOM + ///// + ///// 自定义变焦 + ///// + //ZOOM, //SHORTEN } diff --git a/src/Midjourney.Infrastructure/LoadBalancer/DiscordInstanceImpl.cs b/src/Midjourney.Infrastructure/LoadBalancer/DiscordInstanceImpl.cs index e542df3c..60e2c121 100644 --- a/src/Midjourney.Infrastructure/LoadBalancer/DiscordInstanceImpl.cs +++ b/src/Midjourney.Infrastructure/LoadBalancer/DiscordInstanceImpl.cs @@ -3,6 +3,7 @@ using Midjourney.Infrastructure.Util; using Serilog; using System.Collections.Concurrent; +using System.Threading.Tasks; namespace Midjourney.Infrastructure.LoadBalancer { @@ -202,7 +203,7 @@ public void ExitTask(TaskInfo task) /// 任务信息 /// Discord提交任务的委托 /// 任务提交结果 - public SubmitResultVO SubmitTask(TaskInfo info, Func> discordSubmit) + public SubmitResultVO SubmitTaskAsync(TaskInfo info, Func> discordSubmit) { _taskStoreService.Save(info); @@ -371,6 +372,17 @@ private void SaveAndNotify(TaskInfo task) public Task ActionAsync(string messageId, string customId, int messageFlags, string nonce) => _service.ActionAsync(messageId, customId, messageFlags, nonce); + /// + /// 执行 ZOOM + /// + /// + /// + /// + /// + /// + public Task ZoomAsync(string messageId, string customId, string prompt, string nonce) => + _service.ZoomAsync(messageId, customId, prompt, nonce); + /// /// 异步执行描述任务。 /// diff --git a/src/Midjourney.Infrastructure/LoadBalancer/IDiscordInstance.cs b/src/Midjourney.Infrastructure/LoadBalancer/IDiscordInstance.cs index d95c761a..3b9d95ed 100644 --- a/src/Midjourney.Infrastructure/LoadBalancer/IDiscordInstance.cs +++ b/src/Midjourney.Infrastructure/LoadBalancer/IDiscordInstance.cs @@ -56,7 +56,7 @@ public interface IDiscordInstance : IDiscordService /// 任务实例。 /// 提交操作。 /// 提交结果。 - SubmitResultVO SubmitTask(TaskInfo task, Func> discordSubmit); + SubmitResultVO SubmitTaskAsync(TaskInfo task, Func> discordSubmit); IEnumerable FindRunningTask(Func condition); diff --git a/src/Midjourney.Infrastructure/Midjourney.Infrastructure.csproj b/src/Midjourney.Infrastructure/Midjourney.Infrastructure.csproj index 143df317..ffbe2a04 100644 --- a/src/Midjourney.Infrastructure/Midjourney.Infrastructure.csproj +++ b/src/Midjourney.Infrastructure/Midjourney.Infrastructure.csproj @@ -30,10 +30,12 @@ + + diff --git a/src/Midjourney.Infrastructure/Resources/ApiParams/zoom.json b/src/Midjourney.Infrastructure/Resources/ApiParams/zoom.json new file mode 100644 index 00000000..21bbf31d --- /dev/null +++ b/src/Midjourney.Infrastructure/Resources/ApiParams/zoom.json @@ -0,0 +1,24 @@ +{ + "type": 5, + "application_id": "936929561302675456", + "channel_id": "$channel_id", + "guild_id": "$guild_id", + "data": { + "id": "$message_id", + "custom_id": "MJ::OutpaintCustomZoomModal::$message_hash", + "components": [ + { + "type": 1, + "components": [ + { + "type": 4, + "custom_id": "MJ::OutpaintCustomZoomModal::prompt", + "value": "$prompt" + } + ] + } + ] + }, + "session_id": "$session_id", + "nonce": "$nonce" +} \ No newline at end of file diff --git a/src/Midjourney.Infrastructure/Services/DiscordServiceImpl.cs b/src/Midjourney.Infrastructure/Services/DiscordServiceImpl.cs index 6857d57a..4a19f4be 100644 --- a/src/Midjourney.Infrastructure/Services/DiscordServiceImpl.cs +++ b/src/Midjourney.Infrastructure/Services/DiscordServiceImpl.cs @@ -122,6 +122,30 @@ public async Task ActionAsync(string messageId, string customId, int me return await PostJsonAndCheckStatusAsync(paramsStr); } + /// + /// 自定义变焦 + /// + /// + /// + /// + /// + /// + public async Task ZoomAsync(string messageId, string customId, string prompt, string nonce) + { + customId = customId.Replace("MJ::CustomZoom::", "MJ::OutpaintCustomZoomModal::"); + + string paramsStr = ReplaceInteractionParams(_paramsMap["zoom"], nonce) + .Replace("$message_id", messageId) + .Replace("$prompt", prompt); + + var obj = JObject.Parse(paramsStr); + + obj["data"]["custom_id"] = customId; + + paramsStr = obj.ToString(); + return await PostJsonAndCheckStatusAsync(paramsStr); + } + public async Task RerollAsync(string messageId, string messageHash, int messageFlags, string nonce) { string paramsStr = ReplaceInteractionParams(_paramsMap["reroll"], nonce) diff --git a/src/Midjourney.Infrastructure/Services/IDiscordService.cs b/src/Midjourney.Infrastructure/Services/IDiscordService.cs index 67e9c319..e4fd050b 100644 --- a/src/Midjourney.Infrastructure/Services/IDiscordService.cs +++ b/src/Midjourney.Infrastructure/Services/IDiscordService.cs @@ -55,6 +55,16 @@ public interface IDiscordService /// Task ActionAsync(string messageId, string customId, int messageFlags, string nonce); + /// + /// 执行 ZOOM + /// + /// + /// + /// + /// + /// + Task ZoomAsync(string messageId, string customId, string prompt, string nonce); + /// /// 提交描述任务。 /// diff --git a/src/Midjourney.Infrastructure/Services/ITaskService.cs b/src/Midjourney.Infrastructure/Services/ITaskService.cs index e62972c7..5752c808 100644 --- a/src/Midjourney.Infrastructure/Services/ITaskService.cs +++ b/src/Midjourney.Infrastructure/Services/ITaskService.cs @@ -71,5 +71,13 @@ public interface ITaskService /// /// SubmitResultVO SubmitAction(TaskInfo task, SubmitActionDTO submitAction); + + /// + /// 执行 Modal + /// + /// + /// + /// + SubmitResultVO SubmitModal(TaskInfo task, SubmitModalDTO submitAction); } } \ No newline at end of file diff --git a/src/Midjourney.Infrastructure/Services/TaskServiceImpl.cs b/src/Midjourney.Infrastructure/Services/TaskServiceImpl.cs index b94a40ea..b610e1b0 100644 --- a/src/Midjourney.Infrastructure/Services/TaskServiceImpl.cs +++ b/src/Midjourney.Infrastructure/Services/TaskServiceImpl.cs @@ -2,6 +2,7 @@ using Midjourney.Infrastructure.LoadBalancer; using Midjourney.Infrastructure.Util; using Serilog; +using System.Diagnostics; namespace Midjourney.Infrastructure.Services { @@ -34,7 +35,7 @@ public SubmitResultVO SubmitImagine(TaskInfo info, List dataUrls) info.SetProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, instance.GetInstanceId()); - return instance.SubmitTask(info, async () => + return instance.SubmitTaskAsync(info, async () => { var imageUrls = new List(); foreach (var dataUrl in dataUrls) @@ -72,7 +73,7 @@ public SubmitResultVO SubmitUpscale(TaskInfo task, string targetMessageId, strin { return SubmitResultVO.Fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId); } - return discordInstance.SubmitTask(task, async () => + return discordInstance.SubmitTaskAsync(task, async () => await discordInstance.UpscaleAsync(targetMessageId, index, targetMessageHash, messageFlags, task.GetProperty(Constants.TASK_PROPERTY_NONCE, default))); } @@ -84,7 +85,7 @@ public SubmitResultVO SubmitVariation(TaskInfo task, string targetMessageId, str { return SubmitResultVO.Fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId); } - return discordInstance.SubmitTask(task, async () => + return discordInstance.SubmitTaskAsync(task, async () => await discordInstance.VariationAsync(targetMessageId, index, targetMessageHash, messageFlags, task.GetProperty(Constants.TASK_PROPERTY_NONCE, default))); } @@ -96,7 +97,7 @@ public SubmitResultVO SubmitReroll(TaskInfo task, string targetMessageId, string { return SubmitResultVO.Fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId); } - return discordInstance.SubmitTask(task, async () => + return discordInstance.SubmitTaskAsync(task, async () => await discordInstance.RerollAsync(targetMessageId, targetMessageHash, messageFlags, task.GetProperty(Constants.TASK_PROPERTY_NONCE, default))); } @@ -108,7 +109,7 @@ public SubmitResultVO SubmitDescribe(TaskInfo task, DataUrl dataUrl) return SubmitResultVO.Fail(ReturnCode.NOT_FOUND, "无可用的账号实例"); } task.SetProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.GetInstanceId()); - return discordInstance.SubmitTask(task, async () => + return discordInstance.SubmitTaskAsync(task, async () => { var taskFileName = $"{task.Id}.{MimeTypeUtils.GuessFileSuffix(dataUrl.MimeType)}"; var uploadResult = await discordInstance.UploadAsync(taskFileName, dataUrl); @@ -129,7 +130,7 @@ public SubmitResultVO SubmitBlend(TaskInfo task, List dataUrls, BlendDi return SubmitResultVO.Fail(ReturnCode.NOT_FOUND, "无可用的账号实例"); } task.SetProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.GetInstanceId()); - return discordInstance.SubmitTask(task, async () => + return discordInstance.SubmitTaskAsync(task, async () => { var finalFileNames = new List(); foreach (var dataUrl in dataUrls) @@ -146,7 +147,6 @@ public SubmitResultVO SubmitBlend(TaskInfo task, List dataUrls, BlendDi }); } - /// /// 执行动作 /// @@ -162,21 +162,94 @@ public SubmitResultVO SubmitAction(TaskInfo task, SubmitActionDTO submitAction) } task.SetProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.GetInstanceId()); - return discordInstance.SubmitTask(task, async () => + var targetTask = _taskStoreService.Get(submitAction.TaskId)!; + var messageFlags = targetTask.GetProperty(Constants.TASK_PROPERTY_FLAGS, default); + var messageId = targetTask.GetProperty(Constants.TASK_PROPERTY_MESSAGE_ID, default); + + task.BotType = targetTask.BotType; + task.SetProperty(Constants.TASK_PROPERTY_BOT_TYPE, targetTask.BotType); + task.SetProperty(Constants.TASK_PROPERTY_CUSTOM_ID, submitAction.CustomId); + + // 设置任务的提示信息 = 父级任务的提示信息 + task.Prompt = targetTask.Prompt; + task.PromptEn = targetTask.PromptEn; + + // 如果是 Modal 作业,则直接返回 + if (submitAction.CustomId.StartsWith("MJ::CustomZoom::")) { - var targetTask = _taskStoreService.Get(submitAction.TaskId)!; - var messageFlags = targetTask.GetProperty(Constants.TASK_PROPERTY_FLAGS, default); - var messageId = targetTask.GetProperty(Constants.TASK_PROPERTY_MESSAGE_ID, default); + task.SetProperty(Constants.TASK_PROPERTY_MESSAGE_ID, targetTask.MessageId); + task.SetProperty(Constants.TASK_PROPERTY_FLAGS, messageFlags); - // 设置任务的提示信息 = 父级任务的提示信息 - task.Prompt = targetTask.Prompt; - task.PromptEn = targetTask.PromptEn; - task.BotType = targetTask.BotType; - task.SetProperty(Constants.TASK_PROPERTY_BOT_TYPE, targetTask.BotType); + _taskStoreService.Save(task); + return SubmitResultVO.Of(ReturnCode.SUCCESS, "提交成功", task.Id); + } + return discordInstance.SubmitTaskAsync(task, async () => + { return await discordInstance.ActionAsync(messageId ?? targetTask.MessageId, submitAction.CustomId, messageFlags, task.GetProperty(Constants.TASK_PROPERTY_NONCE, default)); }); } + + /// + /// 执行 Modal + /// + /// + /// + /// + public SubmitResultVO SubmitModal(TaskInfo task, SubmitModalDTO submitAction) + { + var discordInstance = _discordLoadBalancer.ChooseInstance(); + if (discordInstance == null) + { + return SubmitResultVO.Fail(ReturnCode.NOT_FOUND, "无可用的账号实例"); + } + task.SetProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.GetInstanceId()); + + return discordInstance.SubmitTaskAsync(task, async () => + { + var messageFlags = task.GetProperty(Constants.TASK_PROPERTY_FLAGS, default); + var messageId = task.GetProperty(Constants.TASK_PROPERTY_MESSAGE_ID, default); + + var customId = task.GetProperty(Constants.TASK_PROPERTY_CUSTOM_ID, default); + var nonce = task.GetProperty(Constants.TASK_PROPERTY_NONCE, default); + + // 弹出,再执行变焦 + var res = await discordInstance.ActionAsync(messageId, customId, messageFlags, nonce); + if(res.Code != ReturnCode.SUCCESS) + { + return res; + } + + // 等待获取 messageId + // 等待最大超时 5min + var sw = new Stopwatch(); + sw.Start(); + + do + { + Thread.Sleep(500); + task = discordInstance.GetRunningTask(task.Id); + + if (string.IsNullOrWhiteSpace(task.MessageId)) + { + if (sw.ElapsedMilliseconds > 300000) + { + return Message.Of(ReturnCode.NOT_FOUND, "超时,未找到消息 ID"); + } + } + } while (string.IsNullOrWhiteSpace(task.MessageId)); + + nonce = SnowFlake.NextId(); + task.Nonce = nonce; + task.SetProperty(Constants.TASK_PROPERTY_NONCE, nonce); + + // 变焦 + return await discordInstance.ZoomAsync(task.MessageId, + customId, + task.PromptEn, + nonce); + }); + } } } \ No newline at end of file