Skip to content

Commit

Permalink
add modal zoom action
Browse files Browse the repository at this point in the history
  • Loading branch information
trueai-org committed Jul 9, 2024
1 parent d807b51 commit 81cf627
Show file tree
Hide file tree
Showing 12 changed files with 272 additions and 22 deletions.
66 changes: 63 additions & 3 deletions src/Midjourney.API/Controllers/SubmitController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,13 @@ public ActionResult<SubmitResultVO> Action([FromBody] SubmitActionDTO actionDTO)
{
task.Action = TaskAction.ACTION;
}
// 自定义变焦
// "MJ::CustomZoom::439f8670-52e8-4f57-afaa-fa08f6d6c751"
else if (actionDTO.CustomId.StartsWith("MJ::CustomZoom::"))
{
task.Action = TaskAction.ACTION;
task.Description = "Waiting for window confirm";
}
else
{
task.Action = TaskAction.ACTION;
Expand All @@ -336,6 +343,59 @@ public ActionResult<SubmitResultVO> Action([FromBody] SubmitActionDTO actionDTO)
return Ok(_taskService.SubmitAction(task, actionDTO));
}

/// <summary>
/// 提交 Modal
/// </summary>
/// <param name="actionDTO"></param>
/// <returns></returns>
[HttpPost("modal")]
public ActionResult<SubmitResultVO> Modal([FromBody] SubmitModalDTO actionDTO)
{
if (string.IsNullOrWhiteSpace(actionDTO.TaskId) || string.IsNullOrWhiteSpace(actionDTO.Prompt))
{
return BadRequest(SubmitResultVO.Fail(ReturnCode.VALIDATION_ERROR, "参数错误"));
}

var targetTask = _taskStoreService.Get(actionDTO.TaskId);
if (targetTask == null)
{
return NotFound(SubmitResultVO.Fail(ReturnCode.NOT_FOUND, "关联任务不存在或已失效"));
}

var prompt = actionDTO.Prompt;
var task = targetTask;

var promptEn = TranslatePrompt(prompt);
try
{
BannedPromptUtils.CheckBanned(promptEn);
}
catch (BannedPromptException e)
{
return BadRequest(SubmitResultVO.Fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词")
.SetProperty("promptEn", promptEn)
.SetProperty("bannedWord", e.Message));
}

//List<string> base64Array = imagineDTO.Base64Array ?? new List<string>();

//List<DataUrl> dataUrls = new List<DataUrl>();
//try
//{
// dataUrls = ConvertUtils.ConvertBase64Array(base64Array);
//}
//catch (Exception e)
//{
// _logger.LogError(e, "base64格式转换异常");
// return BadRequest(SubmitResultVO.Fail(ReturnCode.VALIDATION_ERROR, "base64格式错误"));
//}

task.PromptEn = promptEn;

return Ok(_taskService.SubmitModal(task, actionDTO));
}


/// <summary>
/// 创建新的任务对象
/// </summary>
Expand All @@ -353,10 +413,10 @@ private TaskInfo NewTask(BaseSubmitDTO baseDTO)
var notifyHook = string.IsNullOrWhiteSpace(baseDTO.NotifyHook) ? _properties.NotifyHook : baseDTO.NotifyHook;
task.SetProperty(Constants.TASK_PROPERTY_NOTIFY_HOOK, notifyHook);

var none = SnowFlake.NextId();
task.Nonce = none;
var nonce = SnowFlake.NextId();
task.Nonce = nonce;

task.SetProperty(Constants.TASK_PROPERTY_NONCE, none);
task.SetProperty(Constants.TASK_PROPERTY_NONCE, nonce);

return task;
}
Expand Down
5 changes: 5 additions & 0 deletions src/Midjourney.Infrastructure/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ public static class Constants
/// </summary>
public const string TASK_PROPERTY_PROGRESS_MESSAGE_ID = "progressMessageId";

/// <summary>
/// 执行动作 custom_id
/// </summary>
public const string TASK_PROPERTY_CUSTOM_ID = "custom_id";

/// <summary>
/// 标志.
/// </summary>
Expand Down
29 changes: 29 additions & 0 deletions src/Midjourney.Infrastructure/Dto/SubmitModalDTO.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using Swashbuckle.AspNetCore.Annotations;

namespace Midjourney.Infrastructure.Dto
{
/// <summary>
/// Imagine提交参数。
/// </summary>
[SwaggerSchema("Imagine提交参数")]
public class SubmitModalDTO : BaseSubmitDTO
{
/// <summary>
/// 提示词。
/// </summary>
[SwaggerSchema("提示词", Description = "Cat")]
public string Prompt { get; set; }

/// <summary>
/// 任务ID。
/// </summary>
[SwaggerSchema("任务ID", Description = "\"1320098173412546\"")]
public string TaskId { get; set; }

/// <summary>
/// 局部重绘的蒙版base64
/// </summary>
[SwaggerSchema("图片base64", Description = "data:image/png;base64,xxx")]
public string MaskBase64 { get; set; }
}
}
5 changes: 4 additions & 1 deletion src/Midjourney.Infrastructure/Enums.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ public enum TaskAction
/// </summary>
OUTPAINT,

//ZOOM
///// <summary>
///// 自定义变焦
///// </summary>
//ZOOM,
//SHORTEN
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Midjourney.Infrastructure.Util;
using Serilog;
using System.Collections.Concurrent;
using System.Threading.Tasks;

namespace Midjourney.Infrastructure.LoadBalancer
{
Expand Down Expand Up @@ -202,7 +203,7 @@ public void ExitTask(TaskInfo task)
/// <param name="info">任务信息</param>
/// <param name="discordSubmit">Discord提交任务的委托</param>
/// <returns>任务提交结果</returns>
public SubmitResultVO SubmitTask(TaskInfo info, Func<Task<Message>> discordSubmit)
public SubmitResultVO SubmitTaskAsync(TaskInfo info, Func<Task<Message>> discordSubmit)
{
_taskStoreService.Save(info);

Expand Down Expand Up @@ -371,6 +372,17 @@ private void SaveAndNotify(TaskInfo task)
public Task<Message> ActionAsync(string messageId, string customId, int messageFlags, string nonce) =>
_service.ActionAsync(messageId, customId, messageFlags, nonce);

/// <summary>
/// 执行 ZOOM
/// </summary>
/// <param name="messageId"></param>
/// <param name="customId"></param>
/// <param name="prompt"></param>
/// <param name="nonce"></param>
/// <returns></returns>
public Task<Message> ZoomAsync(string messageId, string customId, string prompt, string nonce) =>
_service.ZoomAsync(messageId, customId, prompt, nonce);

/// <summary>
/// 异步执行描述任务。
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public interface IDiscordInstance : IDiscordService
/// <param name="task">任务实例。</param>
/// <param name="discordSubmit">提交操作。</param>
/// <returns>提交结果。</returns>
SubmitResultVO SubmitTask(TaskInfo task, Func<Task<Message>> discordSubmit);
SubmitResultVO SubmitTaskAsync(TaskInfo task, Func<Task<Message>> discordSubmit);

IEnumerable<TaskInfo> FindRunningTask(Func<TaskInfo, bool> condition);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
<Content Remove="Resources\ApiParams\upscale.json" />
<Content Remove="Resources\ApiParams\variation.json" />
<Content Remove="Resources\ApiParams\action.json" />
<Content Remove="Resources\ApiParams\zoom.json" />
<Content Remove="Resources\mime.types" />
</ItemGroup>

<ItemGroup>
<EmbeddedResource Include="Resources\ApiParams\zoom.json" />
<EmbeddedResource Include="Resources\ApiParams\reroll.json" />
<EmbeddedResource Include="Resources\ApiParams\upscale.json" />
<EmbeddedResource Include="Resources\ApiParams\action.json" />
Expand Down
24 changes: 24 additions & 0 deletions src/Midjourney.Infrastructure/Resources/ApiParams/zoom.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"type": 5,
"application_id": "936929561302675456",
"channel_id": "$channel_id",
"guild_id": "$guild_id",
"data": {
"id": "$message_id",
"custom_id": "MJ::OutpaintCustomZoomModal::$message_hash",
"components": [
{
"type": 1,
"components": [
{
"type": 4,
"custom_id": "MJ::OutpaintCustomZoomModal::prompt",
"value": "$prompt"
}
]
}
]
},
"session_id": "$session_id",
"nonce": "$nonce"
}
24 changes: 24 additions & 0 deletions src/Midjourney.Infrastructure/Services/DiscordServiceImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,30 @@ public async Task<Message> ActionAsync(string messageId, string customId, int me
return await PostJsonAndCheckStatusAsync(paramsStr);
}

/// <summary>
/// 自定义变焦
/// </summary>
/// <param name="messageId"></param>
/// <param name="customId"></param>
/// <param name="prompt"></param>
/// <param name="nonce"></param>
/// <returns></returns>
public async Task<Message> ZoomAsync(string messageId, string customId, string prompt, string nonce)
{
customId = customId.Replace("MJ::CustomZoom::", "MJ::OutpaintCustomZoomModal::");

string paramsStr = ReplaceInteractionParams(_paramsMap["zoom"], nonce)
.Replace("$message_id", messageId)
.Replace("$prompt", prompt);

var obj = JObject.Parse(paramsStr);

obj["data"]["custom_id"] = customId;

paramsStr = obj.ToString();
return await PostJsonAndCheckStatusAsync(paramsStr);
}

public async Task<Message> RerollAsync(string messageId, string messageHash, int messageFlags, string nonce)
{
string paramsStr = ReplaceInteractionParams(_paramsMap["reroll"], nonce)
Expand Down
10 changes: 10 additions & 0 deletions src/Midjourney.Infrastructure/Services/IDiscordService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ public interface IDiscordService
/// <returns></returns>
Task<Message> ActionAsync(string messageId, string customId, int messageFlags, string nonce);

/// <summary>
/// 执行 ZOOM
/// </summary>
/// <param name="messageId"></param>
/// <param name="customId"></param>
/// <param name="prompt"></param>
/// <param name="nonce"></param>
/// <returns></returns>
Task<Message> ZoomAsync(string messageId, string customId, string prompt, string nonce);

/// <summary>
/// 提交描述任务。
/// </summary>
Expand Down
8 changes: 8 additions & 0 deletions src/Midjourney.Infrastructure/Services/ITaskService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,13 @@ public interface ITaskService
/// <param name="submitAction"></param>
/// <returns></returns>
SubmitResultVO SubmitAction(TaskInfo task, SubmitActionDTO submitAction);

/// <summary>
/// 执行 Modal
/// </summary>
/// <param name="task"></param>
/// <param name="submitAction"></param>
/// <returns></returns>
SubmitResultVO SubmitModal(TaskInfo task, SubmitModalDTO submitAction);
}
}
Loading

0 comments on commit 81cf627

Please sign in to comment.