Skip to content

Commit

Permalink
feat: 加入渠道加权随机功能
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Dec 27, 2023
1 parent 1a8a246 commit bdd611f
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 8 deletions.
5 changes: 5 additions & 0 deletions common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ func GetRandomString(length int) string {
return string(key)
}

func GetRandomInt(max int) int {
//rand.Seed(time.Now().UnixNano())
return rand.Intn(max)
}

func GetTimestamp() int64 {
return time.Now().Unix()
}
Expand Down
35 changes: 30 additions & 5 deletions model/ability.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Ability struct {
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
Weight uint `json:"weight" gorm:"default:0;index"`
}

func GetGroupModels(group string) []string {
Expand All @@ -25,7 +26,7 @@ func GetGroupModels(group string) []string {
}

func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var abilities []Ability
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
Expand All @@ -37,16 +38,39 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else {
err = channelQuery.Order("RAND()").First(&ability).Error
err = channelQuery.Order("weight DESC").Find(&abilities).Error
}
if err != nil {
return nil, err
}
channel := Channel{}
channel.Id = ability.ChannelId
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
if len(abilities) > 0 {
// Randomly choose one
weightSum := uint(0)
for _, ability_ := range abilities {
weightSum += ability_.Weight
}
if weightSum == 0 {
// All weight is 0, randomly choose one
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
} else {
// Randomly choose one
weight := common.GetRandomInt(int(weightSum))
for _, ability_ := range abilities {
weight -= int(ability_.Weight)
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
}
}
}
} else {
return nil, nil
}
err = DB.First(&channel, "id = ?", channel.Id).Error
return &channel, err
}

Expand All @@ -62,6 +86,7 @@ func (channel *Channel) AddAbilities() error {
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
}
abilities = append(abilities, ability)
}
Expand Down
26 changes: 24 additions & 2 deletions model/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
model = "gpt-4-gizmo-*"
}

// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
Expand All @@ -218,8 +219,29 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
// Calculate the total weight of all channels up to endIdx
totalWeight := 0
for _, channel := range channels[:endIdx] {
totalWeight += channel.GetWeight()
}

if totalWeight == 0 {
// If all weights are 0, select a channel randomly
return channels[rand.Intn(endIdx)], nil
}

// Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight)

// Find a channel based on its weight
for _, channel := range channels[:endIdx] {
randomWeight -= channel.GetWeight()
if randomWeight <= 0 {
return channel, nil
}
}
// return the last channel if no channel is found
return channels[endIdx-1], nil
}

func CacheGetChannel(id int) (*Channel, error) {
Expand Down
7 changes: 7 additions & 0 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ func (channel *Channel) GetPriority() int64 {
return *channel.Priority
}

func (channel *Channel) GetWeight() int {
if channel.Weight == nil {
return 0
}
return int(*channel.Weight)
}

func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""
Expand Down
21 changes: 20 additions & 1 deletion web/src/components/ChannelsTable.js
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ const ChannelsTable = () => {
<div>
<InputNumber
style={{width: 70}}
name='name'
name='priority'
onChange={value => {
manageChannel(record.id, 'priority', record, value);
}}
Expand All @@ -174,6 +174,25 @@ const ChannelsTable = () => {
);
},
},
{
title: '权重',
dataIndex: 'weight',
render: (text, record, index) => {
return (
<div>
<InputNumber
style={{width: 70}}
name='weight'
onChange={value => {
manageChannel(record.id, 'weight', record, value);
}}
defaultValue={record.weight}
min={0}
/>
</div>
);
},
},
{
title: '',
dataIndex: 'operate',
Expand Down

0 comments on commit bdd611f

Please sign in to comment.