From c613bcd9c2c15d38511233366ce4981acb630cec Mon Sep 17 00:00:00 2001 From: trueai-org Date: Wed, 10 Jul 2024 14:11:54 +0800 Subject: [PATCH] add zoom, add inpaint, add region --- README.md | 16 +++-- docs/discord-params.md | 72 +++++++++++++++++++ .../Controllers/SubmitController.cs | 32 +++++---- .../BotMessageListener.cs | 16 +++-- src/Midjourney.Infrastructure/Constants.cs | 5 ++ .../DiscordAccountHelper.cs | 4 +- src/Midjourney.Infrastructure/Enums.cs | 12 ++++ .../LoadBalancer/DiscordInstanceImpl.cs | 11 +++ .../Services/DiscordServiceImpl.cs | 48 ++++++++++++- .../Services/IDiscordService.cs | 10 +++ .../Services/ITaskService.cs | 2 +- .../Services/TaskServiceImpl.cs | 34 ++++++--- .../WebSocketHandler.cs | 8 ++- .../WebSocketStarter.cs | 9 ++- 14 files changed, 240 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 790cc0e7..2f2e1f16 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,8 @@ Midjourney api 的 C# 版本。 ## 主要功能 +#### 绘画功能 + - [x] 支持 Imagine 指令和相关动作 [V1/V2.../U1/U2.../R] - [x] Imagine 时支持添加图片 base64,作为垫图 - [x] 支持 Blend (图片混合)、Describe (图生文) 指令 @@ -15,20 +17,22 @@ Midjourney api 的 C# 版本。 - [x] 支持中文 prompt 翻译,需配置百度翻译或 gpt - [x] prompt 敏感词预检测,支持覆盖调整 - [x] user-token 连接 wss,可以获取错误信息和完整功能 -- [x] 支持多账号配置,每个账号可设置对应的任务队列 - [x] 支持 Shorten(prompt分析) 指令 -- [x] 支持焦点移动: Pan ⬅️ ➡️ ⬆️ ⬇️ +- [x] 支持焦点移动: Pan ⬅️➡⬆️⬇️ - [x] 支持局部重绘: Vary (Region) 🖌 - [x] 支持几乎所有的关联按钮动作和 🎛️ Remix 模式 - [x] 账号池持久化,动态维护 -- [x] 支持获取账号 /info、/settings信 息 +- [x] 支持图片变焦,自定义变焦 Zoom 🔍 +- [ ] 支持获取图片的 seed 值 + +#### 账号管理 + +- [ ] 支持多账号配置,每个账号可设置对应的任务队列 +- [ ] 支持获取账号 /info、/settings信 息 - [ ] 账号 settings 设置 - [ ] 支持 niji bot 机器人 -- [ ] 支持获取图片的 seed 值 - [ ] 支持 InsightFace 人脸替换机器人 - [ ] 内嵌管理后台页面 -- [ ] 支持图片变焦: Zoom 🔍 -- [ ] 支持局部重绘 ## 配置项 - mj.accounts: 参考 [账号池配置](./docs/config.md#%E8%B4%A6%E5%8F%B7%E6%B1%A0%E9%85%8D%E7%BD%AE%E5%8F%82%E8%80%83) diff --git a/docs/discord-params.md b/docs/discord-params.md index b826cc18..120fdc61 100644 --- a/docs/discord-params.md +++ b/docs/discord-params.md @@ -9,3 +9,75 @@ 频道的url里取出 服务器ID、频道ID,后续设置到配置项 ![Guild Channel ID](img_9.png) + + +### 如何捕获事件:INTERACTION_IFRAME_MODAL_CREATE + +> 参考 + +https://github.com/dolfies/discord.py-self/discussions/573 + +> 注意修改 command.ts 为:https://discord.com/api/v9/users/@me/application-command-index + +https://github.com/bao-io/midjourney-sdk +https://www.npmjs.com/package/midjourney-sdk + +在处理 WebSocket 连接和会话管理时,捕获特定的前端事件,如 `INTERACTION_IFRAME_MODAL_CREATE`,可以是一个挑战。以下教程基于实际对话,解释了如何正确捕获这一事件。 + + +#### 环境设定 +- 技术栈:TypeScript +- 应用场景:WebSocket 连接管理 + +#### 步骤概述 + +1. **创建 WebSocket 连接**: + 确保在创建 WebSocket 连接时捕获并保存 `session_id`。这个 `session_id` 是后续所有交互的关键。 + +2. **捕获 `session_id`**: + 在建立 WebSocket 连接时,通常会从服务器接收到一个类型为 `READY` 的消息,该消息包含了 `session_id`。 + +3. **使用 `session_id` 发送请求**: + 在发送请求以创建交互式 iframe 模态框时,必须使用从 WebSocket 连接中获得的 `session_id`。 + +4. **处理和监听事件**: + 使用该 `session_id` 发送数据后,系统应能够正确触发 `INTERACTION_IFRAME_MODAL_CREATE` 事件。 + +#### 代码示例 + +```typescript +// 假设 websocket 已经连接并且是可用的状态 +websocket.onmessage = function(event) { + let data = JSON.parse(event.data); + let type = data.type; + + // 当服务器发送 READY 类型的消息时,保存 session_id + if (type === 'READY') { + this.opts.session_id = data.session_id; + } +}; + +// 使用保存的 session_id 发送请求 +function sendInteractionRequest() { + let request = { + type: 'INTERACTION_IFRAME_MODAL_CREATE', + session_id: this.opts.session_id + }; + websocket.send(JSON.stringify(request)); +} + +// 监听事件 +websocket.onmessage = function(event) { + let data = JSON.parse(event.data); + if (data.type === 'INTERACTION_IFRAME_MODAL_CREATE') { + console.log('Modal create interaction triggered successfully.'); + } +}; +``` + +#### 常见问题解决 + +- **问题**: `INTERACTION_IFRAME_MODAL_CREATE` 事件不触发。 + - **解决方案**: 确保发送的 `session_id` 是在建立 WebSocket 连接时接收的那个。如果 `session_id` 错误,该事件不会被正确触发。 + +通过以上步骤和示例代码,您应该能够有效地捕获并处理 `INTERACTION_IFRAME_MODAL_CREATE` 事件。这需要确保您正确管理和使用 `session_id`,以保证数据的一致性和事件的正确触发。 \ No newline at end of file diff --git a/src/Midjourney.API/Controllers/SubmitController.cs b/src/Midjourney.API/Controllers/SubmitController.cs index 0624621a..ac25ed09 100644 --- a/src/Midjourney.API/Controllers/SubmitController.cs +++ b/src/Midjourney.API/Controllers/SubmitController.cs @@ -3,6 +3,7 @@ using Midjourney.Infrastructure.Dto; using Midjourney.Infrastructure.Services; using Midjourney.Infrastructure.Util; +using System.Buffers.Text; using System.Text.RegularExpressions; using TaskStatus = Midjourney.Infrastructure.TaskStatus; @@ -329,6 +330,12 @@ public ActionResult Action([FromBody] SubmitActionDTO actionDTO) task.Action = TaskAction.ACTION; task.Description = "Waiting for window confirm"; } + // 局部绘制 + // MJ::Inpaint::1::da2b1fda-0455-4952-9f0e-d4cb891f8b1e::SOLO + else if (actionDTO.CustomId.StartsWith("MJ::Inpaint::")) + { + task.Action = TaskAction.ACTION; + } else { task.Action = TaskAction.ACTION; @@ -377,22 +384,21 @@ public ActionResult Modal([FromBody] SubmitModalDTO actionDTO) .SetProperty("bannedWord", e.Message)); } - //List base64Array = imagineDTO.Base64Array ?? new List(); - - //List dataUrls = new List(); - //try - //{ - // dataUrls = ConvertUtils.ConvertBase64Array(base64Array); - //} - //catch (Exception e) - //{ - // _logger.LogError(e, "base64格式转换异常"); - // return BadRequest(SubmitResultVO.Fail(ReturnCode.VALIDATION_ERROR, "base64格式错误")); - //} + // 不检查 + DataUrl dataUrl = null; + try + { + //dataUrl = DataUrl.Parse(actionDTO.MaskBase64); + } + catch (Exception e) + { + _logger.LogError(e, "base64格式转换异常"); + return BadRequest(SubmitResultVO.Fail(ReturnCode.VALIDATION_ERROR, "base64格式错误")); + } task.PromptEn = promptEn; - return Ok(_taskService.SubmitModal(task, actionDTO)); + return Ok(_taskService.SubmitModal(task, actionDTO, dataUrl)); } diff --git a/src/Midjourney.Infrastructure/BotMessageListener.cs b/src/Midjourney.Infrastructure/BotMessageListener.cs index 626f6645..045e238a 100644 --- a/src/Midjourney.Infrastructure/BotMessageListener.cs +++ b/src/Midjourney.Infrastructure/BotMessageListener.cs @@ -79,8 +79,6 @@ public async Task StartAsync() // Subscribe a handler to see if a message invokes a command. _client.MessageReceived += HandleCommandAsync; _client.MessageUpdated += MessageUpdatedAsync; - - //_client.InteractionCreated += HandleInteractionAsync; } private DiscordSocketClient _client; @@ -236,7 +234,6 @@ public void OnMessage(JsonElement raw) return; } - // 如果有渠道 id,但不是当前渠道 id,则忽略 if (data.TryGetProperty("channel_id", out JsonElement channelIdElement) && channelIdElement.GetString() != _discordAccount.ChannelId) { @@ -271,7 +268,6 @@ public void OnMessage(JsonElement raw) var id = idElement.GetString(); _logger.Debug($"用户消息, {messageType}, {_discordAccount.GetDisplay()} - {authorName}: {contentStr}, id: {id}, mid: {metaId}"); - // 判断账号是否用量已经用完 if (messageType == MessageType.CREATE && data.TryGetProperty("embeds", out var em)) { @@ -320,6 +316,18 @@ public void OnMessage(JsonElement raw) { task.InteractionMetadataId = id; } + // MJ 局部重绘完成后 + else if (messageType == MessageType.INTERACTION_IFRAME_MODAL_CREATE + && data.TryGetProperty("custom_id", out var custom_id)) + { + task.SetProperty(Constants.TASK_PROPERTY_IFRAME_MODAL_CREATE_CUSTOM_ID, custom_id.GetString()); + task.MessageId = id; + + if (!task.MessageIds.Contains(id)) + { + task.MessageIds.Add(id); + } + } else { task.MessageId = id; diff --git a/src/Midjourney.Infrastructure/Constants.cs b/src/Midjourney.Infrastructure/Constants.cs index ae680b83..288ec72a 100644 --- a/src/Midjourney.Infrastructure/Constants.cs +++ b/src/Midjourney.Infrastructure/Constants.cs @@ -68,6 +68,11 @@ public static class Constants /// public const string TASK_PROPERTY_REFERENCED_MESSAGE_ID = "referencedMessageId"; + /// + /// 局部重绘弹窗 custom_id + /// + public const string TASK_PROPERTY_IFRAME_MODAL_CREATE_CUSTOM_ID = "iframe_modal_custom_id"; + // 任务扩展属性 end /// diff --git a/src/Midjourney.Infrastructure/DiscordAccountHelper.cs b/src/Midjourney.Infrastructure/DiscordAccountHelper.cs index 0545697a..8bc02ec1 100644 --- a/src/Midjourney.Infrastructure/DiscordAccountHelper.cs +++ b/src/Midjourney.Infrastructure/DiscordAccountHelper.cs @@ -103,7 +103,9 @@ public async Task CreateDiscordInstance(DiscordAccount account _messageListener = new BotMessageListener(account.BotToken, account, webProxy); // 用户 WebSocket 连接 - var webSocket = new WebSocketStarter(account, _discordHelper, _messageListener, webProxy); + var webSocket = new WebSocketStarter(account, _discordHelper, _messageListener, + webProxy, discordService); + await webSocket.StartAsync(); _messageListener.Init(discordInstance, _messageHandlers); diff --git a/src/Midjourney.Infrastructure/Enums.cs b/src/Midjourney.Infrastructure/Enums.cs index d785a043..38f6fa26 100644 --- a/src/Midjourney.Infrastructure/Enums.cs +++ b/src/Midjourney.Infrastructure/Enums.cs @@ -153,6 +153,16 @@ public enum MessageType /// /// INTERACTION_SUCCESS, + + /// + /// + /// + INTERACTION_IFRAME_MODAL_CREATE, + + /// + /// + /// + INTERACTION_MODAL_CREATE } public static class MessageTypeExtensions @@ -171,6 +181,8 @@ public static class MessageTypeExtensions "MESSAGE_DELETE" => MessageType.DELETE, "INTERACTION_CREATE" => MessageType.INTERACTION_CREATE, "INTERACTION_SUCCESS" => MessageType.INTERACTION_SUCCESS, + "INTERACTION_IFRAME_MODAL_CREATE" => MessageType.INTERACTION_IFRAME_MODAL_CREATE, + "INTERACTION_MODAL_CREATE" => MessageType.INTERACTION_MODAL_CREATE, _ => null }; } diff --git a/src/Midjourney.Infrastructure/LoadBalancer/DiscordInstanceImpl.cs b/src/Midjourney.Infrastructure/LoadBalancer/DiscordInstanceImpl.cs index 60e2c121..71b8fdef 100644 --- a/src/Midjourney.Infrastructure/LoadBalancer/DiscordInstanceImpl.cs +++ b/src/Midjourney.Infrastructure/LoadBalancer/DiscordInstanceImpl.cs @@ -383,6 +383,17 @@ public Task ActionAsync(string messageId, string customId, int messageF public Task ZoomAsync(string messageId, string customId, string prompt, string nonce) => _service.ZoomAsync(messageId, customId, prompt, nonce); + + /// + /// 局部重绘 + /// + /// + /// + /// + /// + public Task InpaintAsync(string customId, string prompt, string maskBase64) => + _service.InpaintAsync(customId, prompt, maskBase64); + /// /// 异步执行描述任务。 /// diff --git a/src/Midjourney.Infrastructure/Services/DiscordServiceImpl.cs b/src/Midjourney.Infrastructure/Services/DiscordServiceImpl.cs index 4a19f4be..88964dcf 100644 --- a/src/Midjourney.Infrastructure/Services/DiscordServiceImpl.cs +++ b/src/Midjourney.Infrastructure/Services/DiscordServiceImpl.cs @@ -11,7 +11,6 @@ namespace Midjourney.Infrastructure.Services /// public class DiscordServiceImpl : IDiscordService { - private static readonly string DefaultSessionId = "f1a313a09ce079ce252459dc70231f30"; private readonly DiscordAccount _account; private readonly HttpClient _httpClient; private readonly DiscordHelper _discordHelper; @@ -42,6 +41,11 @@ public DiscordServiceImpl(DiscordAccount account, _discordMessageUrl = $"{discordServer}/api/v9/channels/{account.ChannelId}/messages"; } + /// + /// 默认会话ID。 + /// + public string DefaultSessionId { get; set; } = "f1a313a09ce079ce252459dc70231f30"; + public async Task ImagineAsync(string prompt, string nonce) { string paramsStr = ReplaceInteractionParams(_paramsMap["imagine"], nonce); @@ -146,6 +150,48 @@ public async Task ZoomAsync(string messageId, string customId, string p return await PostJsonAndCheckStatusAsync(paramsStr); } + /// + /// 局部重绘 + /// + /// + /// + /// + /// + public async Task InpaintAsync(string customId, string prompt, string maskBase64) + { + try + { + customId = customId.Replace("MJ::iframe::", ""); + + // mask.replace(/^data:.+?;base64,/, ''), + maskBase64 = maskBase64.Replace("data:image/png;base64,", ""); + + var obj = new + { + customId = customId, + //full_prompt = null, + mask = maskBase64, + prompt = prompt, + userId = "0", + username = "0", + }; + var paramsStr = Newtonsoft.Json.JsonConvert.SerializeObject(obj); + var response = await PostJsonAsync("https://936929561302675456.discordsays.com/inpaint/api/submit-job", + paramsStr); + + if (response.StatusCode == System.Net.HttpStatusCode.OK) + { + return Message.Success(); + } + + return Message.Of((int)response.StatusCode, "提交失败"); + } + catch (HttpRequestException e) + { + return ConvertHttpRequestException(e); + } + } + public async Task RerollAsync(string messageId, string messageHash, int messageFlags, string nonce) { string paramsStr = ReplaceInteractionParams(_paramsMap["reroll"], nonce) diff --git a/src/Midjourney.Infrastructure/Services/IDiscordService.cs b/src/Midjourney.Infrastructure/Services/IDiscordService.cs index e4fd050b..c947e842 100644 --- a/src/Midjourney.Infrastructure/Services/IDiscordService.cs +++ b/src/Midjourney.Infrastructure/Services/IDiscordService.cs @@ -65,6 +65,16 @@ public interface IDiscordService /// Task ZoomAsync(string messageId, string customId, string prompt, string nonce); + + /// + /// 局部重绘 + /// + /// + /// + /// + /// + Task InpaintAsync(string customId, string prompt, string maskBase64); + /// /// 提交描述任务。 /// diff --git a/src/Midjourney.Infrastructure/Services/ITaskService.cs b/src/Midjourney.Infrastructure/Services/ITaskService.cs index 5752c808..c2d7cd61 100644 --- a/src/Midjourney.Infrastructure/Services/ITaskService.cs +++ b/src/Midjourney.Infrastructure/Services/ITaskService.cs @@ -78,6 +78,6 @@ public interface ITaskService /// /// /// - SubmitResultVO SubmitModal(TaskInfo task, SubmitModalDTO submitAction); + SubmitResultVO SubmitModal(TaskInfo task, SubmitModalDTO submitAction, DataUrl dataUrl = null); } } \ No newline at end of file diff --git a/src/Midjourney.Infrastructure/Services/TaskServiceImpl.cs b/src/Midjourney.Infrastructure/Services/TaskServiceImpl.cs index b610e1b0..0cdef89b 100644 --- a/src/Midjourney.Infrastructure/Services/TaskServiceImpl.cs +++ b/src/Midjourney.Infrastructure/Services/TaskServiceImpl.cs @@ -175,7 +175,8 @@ public SubmitResultVO SubmitAction(TaskInfo task, SubmitActionDTO submitAction) task.PromptEn = targetTask.PromptEn; // 如果是 Modal 作业,则直接返回 - if (submitAction.CustomId.StartsWith("MJ::CustomZoom::")) + if (submitAction.CustomId.StartsWith("MJ::CustomZoom::") + || submitAction.CustomId.StartsWith("MJ::Inpaint::")) { task.SetProperty(Constants.TASK_PROPERTY_MESSAGE_ID, targetTask.MessageId); task.SetProperty(Constants.TASK_PROPERTY_FLAGS, messageFlags); @@ -197,7 +198,7 @@ public SubmitResultVO SubmitAction(TaskInfo task, SubmitActionDTO submitAction) /// /// /// - public SubmitResultVO SubmitModal(TaskInfo task, SubmitModalDTO submitAction) + public SubmitResultVO SubmitModal(TaskInfo task, SubmitModalDTO submitAction, DataUrl dataUrl = null) { var discordInstance = _discordLoadBalancer.ChooseInstance(); if (discordInstance == null) @@ -216,7 +217,7 @@ public SubmitResultVO SubmitModal(TaskInfo task, SubmitModalDTO submitAction) // 弹出,再执行变焦 var res = await discordInstance.ActionAsync(messageId, customId, messageFlags, nonce); - if(res.Code != ReturnCode.SUCCESS) + if (res.Code != ReturnCode.SUCCESS) { return res; } @@ -240,15 +241,26 @@ public SubmitResultVO SubmitModal(TaskInfo task, SubmitModalDTO submitAction) } } while (string.IsNullOrWhiteSpace(task.MessageId)); - nonce = SnowFlake.NextId(); - task.Nonce = nonce; - task.SetProperty(Constants.TASK_PROPERTY_NONCE, nonce); + // 自定义变焦 + if (customId.StartsWith("MJ::CustomZoom::")) + { + nonce = SnowFlake.NextId(); + task.Nonce = nonce; + task.SetProperty(Constants.TASK_PROPERTY_NONCE, nonce); - // 变焦 - return await discordInstance.ZoomAsync(task.MessageId, - customId, - task.PromptEn, - nonce); + return await discordInstance.ZoomAsync(task.MessageId, customId, task.PromptEn, nonce); + } + // 局部重绘 + else if (customId.StartsWith("MJ::Inpaint::")) + { + var ifarmeCustomId = task.GetProperty(Constants.TASK_PROPERTY_IFRAME_MODAL_CREATE_CUSTOM_ID, default); + return await discordInstance.InpaintAsync(ifarmeCustomId, task.PromptEn, submitAction.MaskBase64); + } + else + { + // 不支持 + return Message.Of(ReturnCode.NOT_FOUND, "不支持的操作"); + } }); } } diff --git a/src/Midjourney.Infrastructure/WebSocketHandler.cs b/src/Midjourney.Infrastructure/WebSocketHandler.cs index 9561a441..8ab8751d 100644 --- a/src/Midjourney.Infrastructure/WebSocketHandler.cs +++ b/src/Midjourney.Infrastructure/WebSocketHandler.cs @@ -1,4 +1,5 @@ using Midjourney.Infrastructure.Domain; +using Midjourney.Infrastructure.Services; using Serilog; using System.IO.Compression; using System.Net; @@ -268,6 +269,10 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) _logger.Information("用户 WebSocket 连接已关闭。"); HandleFailure((int)result.CloseStatus, result.CloseStatusDescription); } + else + { + _logger.Warning("用户收到未知消息"); + } } catch (Exception ex) { @@ -284,7 +289,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) /// 接收到的消息内容 private void HandleMessage(string message) { - //_logger.Information("用户 收到消息: {0}", message); + //_logger.Debug("用户 收到消息: {0}", message); var data = JsonDocument.Parse(message).RootElement; var opCode = data.GetProperty("op").GetInt32(); @@ -381,6 +386,7 @@ private void HandleDispatch(JsonElement data) { _sessionId = data.GetProperty("d").GetProperty("session_id").GetString(); _resumeGatewayUrl = data.GetProperty("d").GetProperty("resume_gateway_url").GetString() + "/?encoding=json&v=9&compress=zlib-stream"; + OnSuccess(); } else if (data.TryGetProperty("t", out var resumed) && resumed.GetString() == "RESUMED") diff --git a/src/Midjourney.Infrastructure/WebSocketStarter.cs b/src/Midjourney.Infrastructure/WebSocketStarter.cs index 62673fea..8509a988 100644 --- a/src/Midjourney.Infrastructure/WebSocketStarter.cs +++ b/src/Midjourney.Infrastructure/WebSocketStarter.cs @@ -1,4 +1,5 @@ using Midjourney.Infrastructure.Domain; +using Midjourney.Infrastructure.Services; using Midjourney.Infrastructure.Util; using Serilog; using System.Net; @@ -21,18 +22,21 @@ public class WebSocketStarter private ClientWebSocket _webSocketSession = null; private ResumeData _resumeData = null; + private DiscordServiceImpl _discordService; public WebSocketStarter( DiscordAccount account, DiscordHelper discordHelper, BotMessageListener userMessageListener, - WebProxy webProxy) + WebProxy webProxy, + DiscordServiceImpl discordService) { _account = account; _userMessageListener = userMessageListener; _discordHelper = discordHelper; _logger = Log.Logger; _webProxy = webProxy; + _discordService = discordService; } public async Task StartAsync() @@ -67,6 +71,7 @@ private void OnSocketSuccess(string sessionId, object sequence, string resumeGat { _resumeData = new ResumeData(sessionId, sequence, resumeGatewayUrl); _running = true; + _discordService.DefaultSessionId = sessionId; NotifyWssLock(ReturnCode.SUCCESS, ""); } @@ -212,6 +217,8 @@ public ResumeData(string sessionId, object sequence, string resumeGatewayUrl) SessionId = sessionId; Sequence = sequence; ResumeGatewayUrl = resumeGatewayUrl; + + } } }