Skip to content

Commit

Permalink
feat: add support for searching similar artworks and enhance hybrid s…
Browse files Browse the repository at this point in the history
…earch functionality
  • Loading branch information
krau committed Jan 26, 2025
1 parent 2a3d580 commit c27e6ea
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 14 deletions.
7 changes: 4 additions & 3 deletions config/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ type searchConfig struct {
}

// meiliSearchConfig holds the Meilisearch settings read from the "search"
// section of the configuration file.
type meiliSearchConfig struct {
	// Host is the configured Meilisearch host.
	Host string `toml:"host" mapstructure:"host" json:"host" yaml:"host"`
	// Key is the configured Meilisearch key.
	Key string `toml:"key" mapstructure:"key" json:"key" yaml:"key"`
	// Index is the name of the index that search requests are issued against.
	Index string `toml:"index" mapstructure:"index" json:"index" yaml:"index"`
	// Embedder is the embedder name passed along with hybrid and
	// similar-document search requests.
	Embedder string `toml:"embedder" mapstructure:"embedder" json:"embedder" yaml:"embedder"`
}
9 changes: 3 additions & 6 deletions config/viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,14 @@ func InitConfig() {
viper.SetDefault("source.kemono.worker", 5)
viper.SetDefault("source.yandere.enable", true)
viper.SetDefault("source.nhentai.enable", true)

viper.SetDefault("source.pixiv.intervel", 60)
viper.SetDefault("source.pixiv.sleep", 1)

viper.SetDefault("source.twitter.fx_twitter_domain", "fxtwitter.com")
viper.SetDefault("source.twitter.sleep", 1)
viper.SetDefault("source.twitter.intervel", 60)

viper.SetDefault("storage.cache_dir", "./cache")
viper.SetDefault("storage.cache_ttl", 86400)
// viper.SetDefault("storage.original_type", "local")
// viper.SetDefault("storage.regular_type", "local")
// viper.SetDefault("storage.thumb_type", "local")
// viper.SetDefault("storage.local.enable", true)
viper.SetDefault("storage.local.path", "./manyacg")
viper.SetDefault("storage.alist.token_expire", 86400)
viper.SetDefault("storage.regular_format", "webp")
Expand All @@ -95,6 +89,9 @@ func InitConfig() {
viper.SetDefault("database.database", "manyacg")
viper.SetDefault("database.max_staleness", 120)

viper.SetDefault("search.meilisearch.index", "manyacg")
viper.SetDefault("search.meilisearch.embedder", "default")

if err := viper.ReadInConfig(); err != nil {
fmt.Printf("error when reading config: %s\n", err)
os.Exit(1)
Expand Down
3 changes: 2 additions & 1 deletion errs/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ var (

ErrAliasAlreadyUsed = errors.New("alias already used")

ErrNotEnabledHybridSearch = errors.New("hybrid search not enabled")
ErrSearchEngineUnavailable = errors.New("search engine unavailable")
ErrArtworksNotFound = errors.New("artworks not found")
)
47 changes: 44 additions & 3 deletions service/artwork_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive"
)

func HybridSearchArtworks(ctx context.Context, queryText string, hybridSemanticRatio float64, limit int64, options ...*types.AdapterOption) ([]*types.Artwork, error) {
func HybridSearchArtworks(ctx context.Context, queryText string, hybridSemanticRatio float64, offset, limit int64, options ...*types.AdapterOption) ([]*types.Artwork, error) {
if common.MeilisearchClient == nil {
return nil, errs.ErrNotEnabledHybridSearch
return nil, errs.ErrSearchEngineUnavailable
}
index := common.MeilisearchClient.Index(config.Cfg.Search.MeiliSearch.Index)
resp, err := index.SearchWithContext(ctx, queryText, &meilisearch.SearchRequest{
Limit: limit,
Offset: offset,
Limit: limit,
Hybrid: &meilisearch.SearchRequestHybrid{
Embedder: config.Cfg.Search.MeiliSearch.Embedder,
SemanticRatio: hybridSemanticRatio,
},
})
Expand Down Expand Up @@ -52,3 +54,42 @@ func HybridSearchArtworks(ctx context.Context, queryText string, hybridSemanticR
}
return adapter.ConvertToArtworks(ctx, artworkModels, options...)
}

// SearchSimilarArtworks asks the search engine for documents similar to the
// document identified by artworkIdStr and returns the matching artworks.
//
// offset and limit are forwarded unchanged to the similar-document query, and
// options are forwarded unchanged to adapter.ConvertToArtworks.
//
// It returns errs.ErrSearchEngineUnavailable when no Meilisearch client is
// configured. Any failure while querying, decoding a hit, parsing a document
// ID, or loading an artwork from the database aborts the whole call and is
// returned as-is.
func SearchSimilarArtworks(ctx context.Context, artworkIdStr string, offset, limit int64, options ...*types.AdapterOption) ([]*types.Artwork, error) {
	if common.MeilisearchClient == nil {
		return nil, errs.ErrSearchEngineUnavailable
	}

	searchIndex := common.MeilisearchClient.Index(config.Cfg.Search.MeiliSearch.Index)
	query := &meilisearch.SimilarDocumentQuery{
		Id:       artworkIdStr,
		Embedder: config.Cfg.Search.MeiliSearch.Embedder,
		Offset:   offset,
		Limit:    limit,
	}
	var result meilisearch.SimilarDocumentResult
	err := searchIndex.SearchSimilarDocumentsWithContext(ctx, query, &result)
	if err != nil {
		return nil, err
	}

	// The hits are re-encoded to JSON and decoded again so they land in the
	// typed search-document structure.
	encoded, err := sonic.Marshal(result.Hits)
	if err != nil {
		return nil, err
	}
	docs := make([]*types.ArtworkSearchDocument, 0, len(result.Hits))
	if err := sonic.Unmarshal(encoded, &docs); err != nil {
		return nil, err
	}

	// Load the stored model for every hit, keeping the order of the hits.
	models := make([]*types.ArtworkModel, 0, len(docs))
	for i := range docs {
		id, err := primitive.ObjectIDFromHex(docs[i].ID)
		if err != nil {
			return nil, err
		}
		model, err := dao.GetArtworkByID(ctx, id)
		if err != nil {
			return nil, err
		}
		models = append(models, model)
	}
	return adapter.ConvertToArtworks(ctx, models, options...)
}
4 changes: 4 additions & 0 deletions telegram/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ var (
Command: "hybrid",
Description: "基于语义与关键字混合搜索作品",
},
{
Command: "similar",
Description: "获取与回复的图片相似的作品",
},
}

AdminCommands = []telego.BotCommand{
Expand Down
1 change: 1 addition & 0 deletions telegram/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func RegisterHandlers(hg *telegohandler.HandlerGroup) {
mg.HandleMessageCtx(CalculatePicture, telegohandler.CommandEqual("hash"))
mg.HandleMessageCtx(GetStats, telegohandler.CommandEqual("stats"))
mg.HandleMessageCtx(HybridSearchArtworks, telegohandler.CommandEqual("hybrid"))
mg.HandleMessageCtx(SearchSimilarArtworks, telegohandler.CommandEqual("similar"))

// Admin commands
mg.HandleMessageCtx(SetAdmin, telegohandler.CommandEqual("set_admin"))
Expand Down
55 changes: 54 additions & 1 deletion telegram/handlers/query_artwork.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func HybridSearchArtworks(ctx context.Context, bot *telego.Bot, message telego.M
}
queryText = strings.Join(args[:len(args)-1], " ")
}
artworks, err := service.HybridSearchArtworks(ctx, queryText, hybridSemanticRatio, 10)
artworks, err := service.HybridSearchArtworks(ctx, queryText, hybridSemanticRatio, 0, 10)
if err != nil {
common.Logger.Errorf("搜索失败: %s", err)
utils.ReplyMessage(bot, message, "搜索失败, 请联系管理员检查搜索引擎设置与状态")
Expand Down Expand Up @@ -137,3 +137,56 @@ func HybridSearchArtworks(ctx context.Context, bot *telego.Bot, message telego.M
}

}

// SearchSimilarArtworks handles a message that replies to another message:
// it resolves the replied-to message to a known artwork, asks the search
// service for up to 10 similar artworks, and sends them back to the chat as a
// media group whose captions link to each artwork's source URL.
//
// The user gets a text reply instead when the search engine is not configured,
// the message is not a reply, no supported link is found, the artwork cannot
// be resolved, the search fails, or nothing sendable is found. A failure to
// send the media group is only logged.
func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego.Message) {
	if common.MeilisearchClient == nil {
		utils.ReplyMessage(bot, message, "搜索引擎不可用")
		return
	}
	if message.ReplyToMessage == nil {
		utils.ReplyMessage(bot, message, "请回复一张图片")
		return
	}
	sourceURL := utils.FindSourceURLForMessage(message.ReplyToMessage)
	if sourceURL == "" {
		utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接")
		return
	}

	artwork, err := service.GetArtworkByURL(ctx, sourceURL)
	if err != nil {
		common.Logger.Errorf("获取作品信息失败: %s", err)
		utils.ReplyMessage(bot, message, "获取作品信息失败")
		return
	}
	if artwork == nil {
		// Handled separately from the error case so a nil error is never
		// formatted with %s (which would log "%!s(<nil>)").
		common.Logger.Errorf("获取作品信息失败: 未找到作品 %s", sourceURL)
		utils.ReplyMessage(bot, message, "获取作品信息失败")
		return
	}

	artworks, err := service.SearchSimilarArtworks(ctx, artwork.ID, 0, 10)
	if err != nil {
		common.Logger.Errorf("搜索失败: %s", err)
		utils.ReplyMessage(bot, message, "搜索失败")
		return
	}
	if len(artworks) == 0 {
		utils.ReplyMessage(bot, message, "未找到相似的作品")
		return
	}
	// The media group is capped at 10 entries even if the service returns more.
	if len(artworks) > 10 {
		artworks = artworks[:10]
	}

	inputMedias := make([]telego.InputMedia, 0, len(artworks))
	for _, similar := range artworks {
		// Skip entries that have no picture to show; indexing Pictures[0]
		// on them would panic.
		if similar == nil || len(similar.Pictures) == 0 {
			continue
		}
		picture := similar.Pictures[0]
		var file telego.InputFile
		if picture.TelegramInfo != nil && picture.TelegramInfo.PhotoFileID != "" {
			// Reuse the stored Telegram file ID when one is available.
			file = telegoutil.FileFromID(picture.TelegramInfo.PhotoFileID)
		} else {
			// Otherwise point Telegram at a URL built from the configured
			// WSRVURL and the original picture location.
			// NOTE(review): picture.Original is inserted into the query string
			// without URL-escaping — confirm it can never contain '&' or '#'.
			photoURL := fmt.Sprintf("%s/?url=%s&w=2560&h=2560&we&output=jpg", config.Cfg.WSRVURL, picture.Original)
			file = telegoutil.FileFromURL(photoURL)
		}
		caption := fmt.Sprintf("<a href=\"%s\">%s</a>", similar.SourceURL, common.EscapeHTML(similar.Title))
		inputMedias = append(inputMedias, telegoutil.MediaPhoto(file).WithCaption(caption).WithParseMode(telego.ModeHTML))
	}
	if len(inputMedias) == 0 {
		utils.ReplyMessage(bot, message, "未找到相似的作品")
		return
	}

	mediaGroup := telegoutil.MediaGroup(message.Chat.ChatID(), inputMedias...)
	if _, err = bot.SendMediaGroup(mediaGroup); err != nil {
		common.Logger.Errorf("发送图片失败: %s", err)
	}
}

0 comments on commit c27e6ea

Please sign in to comment.