Skip to content

Commit

Permalink
fix: 优化 mj 获取进度
Browse files Browse the repository at this point in the history
  • Loading branch information
xyfacai committed Dec 23, 2023
1 parent 7c4719b commit fd4ef08
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 11 deletions.
197 changes: 187 additions & 10 deletions controller/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ import (
"time"
)

func UpdateMidjourneyTask() {
/*func UpdateMidjourneyTask() {
//revocer
//imageModel := "midjourney"
ctx := context.TODO()
imageModel := "midjourney"
defer func() {
if err := recover(); err != nil {
Expand All @@ -28,27 +30,28 @@ func UpdateMidjourneyTask() {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
log.Printf("检测到未完成的任务数有: %v", len(tasks))
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
for _, task := range tasks {
log.Printf("未完成的任务信息: %v", task)
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
task.Status = "FAILURE"
task.Progress = "100%"
err := task.Update()
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
log.Printf("requestUrl: %s", requestUrl)
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
Expand Down Expand Up @@ -111,7 +114,7 @@ func UpdateMidjourneyTask() {
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
log.Println(task.MjId + " 构建失败," + task.FailReason)
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
Expand All @@ -126,8 +129,8 @@ func UpdateMidjourneyTask() {
if err != nil {
log.Println("fail to increase user quota")
}
logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, 1, logContent)
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
Expand All @@ -142,6 +145,180 @@ func UpdateMidjourneyTask() {
}
}
}
*/

func UpdateMidjourneyTaskBulk() {
//revocer
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
imageModel := "midjourney"
ctx := context.TODO()
for {
time.Sleep(time.Duration(15) * time.Second)

tasks := model.GetAllUnFinishTasks()
if len(tasks) == 0 {
continue
}

common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
for _, task := range tasks {
if task.MjId == "" {
continue
}
taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
}
if len(taskChannelM) == 0 {
continue
}

for channelId, taskIds := range taskChannelM {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil {
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)

body, _ := json.Marshal(map[string]any{
"ids": taskIds,
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []Midjourney
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err))
continue
}
resp.Body.Close()
req.Body.Close()
cancel()

for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}

task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
}
}
}
}
}

func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
if oldTask.Code != 1 {
return true
}
if oldTask.Progress != newTask.Progress {
return true
}
if oldTask.PromptEn != newTask.PromptEn {
return true
}
if oldTask.State != newTask.State {
return true
}
if oldTask.SubmitTime != newTask.SubmitTime {
return true
}
if oldTask.StartTime != newTask.StartTime {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.ImageUrl != newTask.ImageUrl {
return true
}
if oldTask.Status != newTask.Status {
return true
}
if oldTask.FailReason != newTask.FailReason {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.Progress != "100%" && newTask.FailReason != "" {
return true
}

return false
}

func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
go controller.UpdateMidjourneyTask()
go controller.UpdateMidjourneyTaskBulk()
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
Expand Down
18 changes: 18 additions & 0 deletions model/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}

var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex

func InitChannelCache() {
Expand All @@ -149,10 +150,12 @@ func InitChannelCache() {
groups[ability.Group] = true
}
newGroup2model2channels := make(map[string]map[string][]*Channel)
newChannelsIDM := make(map[int]*Channel)
for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel)
}
for _, channel := range channels {
newChannelsIDM[channel.Id] = channel
groups := strings.Split(channel.Group, ",")
for _, group := range groups {
models := strings.Split(channel.Models, ",")
Expand All @@ -177,6 +180,7 @@ func InitChannelCache() {

channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelsIDM = newChannelsIDM
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
}
Expand Down Expand Up @@ -217,3 +221,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
idx := rand.Intn(endIdx)
return channels[idx], nil
}

func CacheGetChannel(id int) (*Channel, error) {
if !common.MemoryCacheEnabled {
return GetChannelById(id, true)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()

c, ok := channelsIDM[id]
if !ok {
return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
}
return c, nil
}
6 changes: 6 additions & 0 deletions model/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,9 @@ func (midjourney *Midjourney) Update() error {
err = DB.Save(midjourney).Error
return err
}

func MjBulkUpdate(taskIDs []string, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("mj_id in (?)", taskIDs).
Updates(params).Error
}

0 comments on commit fd4ef08

Please sign in to comment.