From 728d3094436a28954e79c00968add8c01763ab3b Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Mon, 27 Jan 2025 14:22:20 +0800 Subject: [PATCH] feat: refactor artwork query handling and improve image source retrieval --- telegram/handlers/query_artwork.go | 140 ++++++++++++++++------------- 1 file changed, 78 insertions(+), 62 deletions(-) diff --git a/telegram/handlers/query_artwork.go b/telegram/handlers/query_artwork.go index b6f6d0f..cc8c6ab 100644 --- a/telegram/handlers/query_artwork.go +++ b/telegram/handlers/query_artwork.go @@ -15,7 +15,6 @@ import ( "github.com/krau/ManyACG/config" "github.com/krau/ManyACG/service" - "github.com/krau/ManyACG/sources" "github.com/krau/ManyACG/telegram/utils" "github.com/krau/ManyACG/types" @@ -54,10 +53,7 @@ func RandomPicture(ctx context.Context, bot *telego.Bot, message telego.Message) if picture.TelegramInfo.PhotoFileID != "" { file = telegoutil.FileFromID(picture.TelegramInfo.PhotoFileID) } else { - photoURL := picture.Original - if artwork[0].SourceType == types.SourceTypePixiv { - photoURL = sources.GetPixivRegularURL(photoURL) - } + photoURL := fmt.Sprintf("%s/?url=%s&w=2560&h=2560&we&output=jpg", config.Cfg.WSRVURL, picture.Original) file = telegoutil.FileFromURL(photoURL) } caption := fmt.Sprintf("[%s](%s)", common.EscapeMarkdown(artwork[0].Title), artwork[0].SourceURL) @@ -119,29 +115,7 @@ func HybridSearchArtworks(ctx context.Context, bot *telego.Bot, message telego.M if len(artworks) > 10 { artworks = slice.Shuffle(artworks)[:10] } - - inputMedias := make([]telego.InputMedia, 0, len(artworks)) - for _, artwork := range artworks { - picture := artwork.Pictures[0] - var file telego.InputFile - if picture.TelegramInfo != nil && picture.TelegramInfo.PhotoFileID != "" { - file = telegoutil.FileFromID(picture.TelegramInfo.PhotoFileID) - } else { - photoURL := fmt.Sprintf("%s/?url=%s&w=2560&h=2560&we&output=jpg", config.Cfg.WSRVURL, picture.Original) - file = telegoutil.FileFromURL(photoURL) - } - caption := fmt.Sprintf("%s", artwork.SourceURL, common.EscapeHTML(artwork.Title)) - inputMedias = append(inputMedias, telegoutil.MediaPhoto(file).WithCaption(caption).WithParseMode(telego.ModeHTML)) - } - mediaGroup := telegoutil.MediaGroup(message.Chat.ChatID(), inputMedias...).WithReplyParameters(&telego.ReplyParameters{ - MessageID: message.MessageID, - ChatID: message.Chat.ChatID(), - }) - _, err = bot.SendMediaGroup(mediaGroup) - if err != nil { - common.Logger.Errorf("发送图片失败: %s", err) - } - + handleSendResultArtworks(artworks, message, bot) } func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego.Message) { @@ -156,43 +130,33 @@ func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego. var sourceURL string sourceURL = utils.FindSourceURLForMessage(message.ReplyToMessage) if sourceURL == "" { - if message.ReplyToMessage.Photo != nil || message.ReplyToMessage.Document != nil { - handleGetSourceURLFromPicture := func() (string, error) { - file, err := utils.GetMessagePhotoFile(bot, message.ReplyToMessage) - if err != nil { - return "", err - } - hash, err := common.GetImagePhashFromReader(bytes.NewReader(file)) - if err != nil { - return "", err - } - pictures, err := service.GetPicturesByHashHammingDistance(ctx, hash, 10) - if err != nil { - return "", err - } - if len(pictures) == 0 { - return "", errors.New("not found similar pictures by hash") - } - picture := pictures[0] - artworkID, err := primitive.ObjectIDFromHex(picture.ArtworkID) - if err != nil { - return "", err - } - artwork, err := service.GetArtworkByID(ctx, artworkID) - if err != nil { - return "", err - } - return artwork.SourceURL, nil - } - var err error - sourceURL, err = handleGetSourceURLFromPicture() - if err != nil { - common.Logger.Warnf("获取图片链接失败: %s", err) + if message.ReplyToMessage.Photo == nil && message.ReplyToMessage.Document == nil { + utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接") + return + } + var err error + var file []byte + sourceURL, file, err = handleGetSourceURLFromPicture(ctx, bot, message) + if err != nil { + common.Logger.Warnf("获取图片链接失败: %s", err) + if file == nil || common.TaggerClient == nil { utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接或图片") return } - } else { - utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接") + result, err := common.TaggerClient.Predict(ctx, file) + if err != nil || len(result.PredictedTags) == 0 { + common.Logger.Errorf("图片识别失败: %s", err) + utils.ReplyMessage(bot, message, "图片识别失败") + return + } + queryText := strings.Join(result.PredictedTags, ",") + artworks, err := service.HybridSearchArtworks(ctx, queryText, 0.8, 0, 10) + if err != nil || len(artworks) == 0 { + common.Logger.Errorf("搜索失败: %s", err) + utils.ReplyMessage(bot, message, "搜索失败") + return + } + handleSendResultArtworks(artworks, message, bot) return } } @@ -257,3 +221,55 @@ func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego. common.Logger.Errorf("发送图片失败: %s", err) } } + +func handleGetSourceURLFromPicture(ctx context.Context, bot *telego.Bot, message telego.Message) (string, []byte, error) { + file, err := utils.GetMessagePhotoFile(bot, message.ReplyToMessage) + if err != nil { + return "", nil, err + } + hash, err := common.GetImagePhashFromReader(bytes.NewReader(file)) + if err != nil { + return "", file, err + } + pictures, err := service.GetPicturesByHashHammingDistance(ctx, hash, 10) + if err != nil { + return "", file, err + } + if len(pictures) == 0 { + return "", file, errors.New("not found similar pictures by hash") + } + picture := pictures[0] + artworkID, err := primitive.ObjectIDFromHex(picture.ArtworkID) + if err != nil { + return "", file, err + } + artwork, err := service.GetArtworkByID(ctx, artworkID) + if err != nil { + return "", file, err + } + return artwork.SourceURL, file, nil +} + +func handleSendResultArtworks(artworks []*types.Artwork, message telego.Message, bot *telego.Bot) { + inputMedias := make([]telego.InputMedia, 0, len(artworks)) + for _, artwork := range artworks { + picture := artwork.Pictures[0] + var file telego.InputFile + if picture.TelegramInfo != nil && picture.TelegramInfo.PhotoFileID != "" { + file = telegoutil.FileFromID(picture.TelegramInfo.PhotoFileID) + } else { + photoURL := fmt.Sprintf("%s/?url=%s&w=2560&h=2560&we&output=jpg", config.Cfg.WSRVURL, picture.Original) + file = telegoutil.FileFromURL(photoURL) + } + caption := fmt.Sprintf("%s", artwork.SourceURL, common.EscapeHTML(artwork.Title)) + inputMedias = append(inputMedias, telegoutil.MediaPhoto(file).WithCaption(caption).WithParseMode(telego.ModeHTML)) + } + mediaGroup := telegoutil.MediaGroup(message.Chat.ChatID(), inputMedias...).WithReplyParameters(&telego.ReplyParameters{ + MessageID: message.MessageID, + ChatID: message.Chat.ChatID(), + }) + _, err := bot.SendMediaGroup(mediaGroup) + if err != nil { + common.Logger.Errorf("发送图片失败: %s", err) + } +}