diff --git a/common/utils.go b/common/utils.go index f9f3bc25..4ec0b2e9 100644 --- a/common/utils.go +++ b/common/utils.go @@ -168,6 +168,11 @@ func GetRandomString(length int) string { return string(key) } +func GetRandomInt(max int) int { + //rand.Seed(time.Now().UnixNano()) + return rand.Intn(max) +} + func GetTimestamp() int64 { return time.Now().Unix() } diff --git a/model/ability.go b/model/ability.go index f060991d..5679b61e 100644 --- a/model/ability.go +++ b/model/ability.go @@ -11,6 +11,7 @@ type Ability struct { ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` + Weight uint `json:"weight" gorm:"default:0;index"` } func GetGroupModels(group string) []string { @@ -25,7 +26,7 @@ func GetGroupModels(group string) []string { } func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - ability := Ability{} + var abilities []Ability groupCol := "`group`" trueVal := "1" if common.UsingPostgreSQL { @@ -37,16 +38,39 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) if common.UsingSQLite || common.UsingPostgreSQL { - err = channelQuery.Order("RANDOM()").First(&ability).Error + err = channelQuery.Order("weight DESC").Find(&abilities).Error } else { - err = channelQuery.Order("RAND()").First(&ability).Error + err = channelQuery.Order("weight DESC").Find(&abilities).Error } if err != nil { return nil, err } channel := Channel{} - channel.Id = ability.ChannelId - err = DB.First(&channel, "id = ?", ability.ChannelId).Error + if len(abilities) > 0 { + // Randomly choose one + weightSum := uint(0) + for _, ability_ := range abilities { + weightSum += ability_.Weight + } + if weightSum == 0 { + // All weight is 0, randomly choose one + channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId + } else { + // Randomly choose one + weight := common.GetRandomInt(int(weightSum)) + for _, ability_ := range abilities { + weight -= int(ability_.Weight) + //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight) + if weight <= 0 { + channel.Id = ability_.ChannelId + break + } + } + } + } else { + return nil, nil + } + err = DB.First(&channel, "id = ?", channel.Id).Error return &channel, err } @@ -62,6 +86,7 @@ func (channel *Channel) AddAbilities() error { ChannelId: channel.Id, Enabled: channel.Status == common.ChannelStatusEnabled, Priority: channel.Priority, + Weight: uint(channel.GetWeight()), } abilities = append(abilities, ability) } diff --git a/model/cache.go b/model/cache.go index c575ad95..975ceddf 100644 --- a/model/cache.go +++ b/model/cache.go @@ -198,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error model = "gpt-4-gizmo-*" } + // if memory cache is disabled, get channel directly from database if !common.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model) } @@ -218,8 +219,29 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error } } } - idx := rand.Intn(endIdx) - return channels[idx], nil + // Calculate the total weight of all channels up to endIdx + totalWeight := 0 + for _, channel := range channels[:endIdx] { + totalWeight += channel.GetWeight() + } + + if totalWeight == 0 { + // If all weights are 0, select a channel randomly + return channels[rand.Intn(endIdx)], nil + } + + // Generate a random value in the range [0, totalWeight) + randomWeight := rand.Intn(totalWeight) + + // Find a channel based on its weight + for _, channel := range channels[:endIdx] { + randomWeight -= channel.GetWeight() + if randomWeight <= 0 { + return channel, nil + } + } + // return the last channel if no channel is found + return channels[endIdx-1], nil } func CacheGetChannel(id int) (*Channel, error) { diff --git a/model/channel.go b/model/channel.go index 1f7dd2de..7cdc2f6a 100644 --- a/model/channel.go +++ b/model/channel.go @@ -113,6 +113,13 @@ func (channel *Channel) GetPriority() int64 { return *channel.Priority } +func (channel *Channel) GetWeight() int { + if channel.Weight == nil { + return 0 + } + return int(*channel.Weight) +} + func (channel *Channel) GetBaseURL() string { if channel.BaseURL == nil { return "" diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 1a9e11f2..0b07420c 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -163,7 +163,7 @@ const ChannelsTable = () => {
{ manageChannel(record.id, 'priority', record, value); }} @@ -174,6 +174,25 @@ const ChannelsTable = () => { ); }, }, + { + title: '权重', + dataIndex: 'weight', + render: (text, record, index) => { + return ( +
+ { + manageChannel(record.id, 'weight', record, value); + }} + defaultValue={record.weight} + min={0} + /> +
+ ); + }, + }, { title: '', dataIndex: 'operate',