diff --git a/go.mod b/go.mod index 432aeb7c..7ba7ae0f 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( go.uber.org/atomic v1.4.0 golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 golang.org/x/net v0.0.0-20200226121028-0de0cce0169b + golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e k8s.io/gengo v0.0.0-20201113003025-83324d819ded nhooyr.io/websocket v1.8.6 ) diff --git a/go.sum b/go.sum index 10e2b96c..ebb4819a 100644 --- a/go.sum +++ b/go.sum @@ -83,6 +83,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s= +golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200505023115-26f46d2f7ef8/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= diff --git a/ratelimit_roundtripper.go b/ratelimit_roundtripper.go new file mode 100644 index 00000000..eeda7cf9 --- /dev/null +++ b/ratelimit_roundtripper.go @@ -0,0 +1,155 @@ +package disgord +// +// import ( +// "net/http" +// "net/url" +// "strings" +// "sync" +// +// "golang.org/x/time/rate" +// ) +// +// type DiscordURL url.URL +// +// func (d *DiscordURL) IsMajor() bool { +// for _, prefix := range []string{"guilds", "channels", "webhooks"} { +// if strings.HasPrefix(d.Path, prefix) { +// return true +// } +// } +// return false +// } +// +// func (d *DiscordURL) Hash(method string) string { +// matches := regexpURLSnowflakes.FindAllString(d.Path, -1) +// isMajor := d.IsMajor() +// buffer := d.Path +// for i := range matches { +// if i == 0 && isMajor { +// continue +// } +// +// buffer = strings.ReplaceAll(buffer, matches[i], "/{id}/") +// } +// +// // check for reaction endpoints, convert emoji identifier to {emoji} +// reactionPrefixMatch := regexpURLReactionPrefix.FindAllString(buffer, -1) +// if reactionPrefixMatch != nil { +// if regexpURLReactionEmoji.FindAllString(buffer, -1) != nil { +// reactionEmojis := regexpURLReactionEmojiSegment.FindAllString(buffer, -1) +// for i := range reactionEmojis { +// buffer = strings.ReplaceAll(buffer, reactionEmojis[i], "/reactions/{emoji}") +// } +// } else { +// // corner case for urls with emojis +// suffix := buffer[len(reactionPrefixMatch[0]):] +// until := len(suffix) +// for i, r := range suffix { +// if r == '/' { +// until = i +// break +// } +// } +// newSuffix := "{emoji}" + suffix[until:] +// buffer = buffer[:len(buffer)-len(suffix)] + newSuffix +// } +// } +// +// if strings.HasSuffix(buffer, "/") { +// buffer = buffer[:len(buffer)-1] +// } +// return method + ":" + buffer +// } +// +// type reqWaiterChan chan *rate.Limiter +// +// type RateLimit struct { +// Next http.RoundTripper +// +// vtableMu sync.RWMutex +// vtable map[string]string +// +// waitersMu sync.RWMutex +// waiters map[string]reqWaiterChan +// +// bucketsMy sync.RWMutex +// buckets map[string]*rate.Limiter +// } +// +// var _ http.RoundTripper = &RateLimit{} +// +// func (r *RateLimit) RoundTrip(req *http.Request) (resp *http.Response, err error) { +// if !r.isDiscordAPIRequest(req) { +// return r.Next.RoundTrip(req) +// } +// +// durl := DiscordURL(*req.URL) +// localHash := durl.Hash(req.Method) +// discordHash := r.discordHash(localHash) +// if discordHash == "" { +// waitChan := r.waiter(localHash, func() { +// resp, err := r.Next.RoundTrip(req) +// if err != nil { +// return nil, err +// } +// }) +// } +// +// return r.rateLimit(req, func() (*http.Response, error) { +// return r.Next.RoundTrip(req) +// }) +// } +// +// func (r *RateLimit) isDiscordAPIRequest(req *http.Request) bool { +// const DiscordAPIURLPrefix = "https://discord.com/api/v" +// return strings.HasPrefix(req.URL.String(), DiscordAPIURLPrefix) +// } +// +// func (r *RateLimit) discordHash(localHash string) string { +// r.vtableMu.RLock() +// defer r.vtableMu.RUnlock() +// +// if discordHash, ok := r.vtable[localHash]; ok { +// return discordHash +// } else { +// return "" +// } +// } +// +// func (r *RateLimit) waiter(localHash string) reqWaiterChan { +// r.waitersMu.RLock() +// defer r.waitersMu.RUnlock() +// +// if waiter, ok := r.waiters[localHash]; ok { +// return waiter +// } else { +// return nil +// } +// } +// +// func (r *RateLimit) bucket(hash string, setupBucket func() (discordHash string, bucket *rate.Limiter)) string { +// r.vtableMu.RLock() +// if discordHash, ok := r.vtable[hash]; ok { +// r.vtableMu.RUnlock() +// return discordHash +// } +// r.vtableMu.RUnlock() +// +// r.vtableMu.Lock() +// // check if another request got here faster +// if discordHash, ok := r.vtable[hash]; ok { +// r.vtableMu.Unlock() +// return discordHash +// } +// +// discordHash, bucket := setupBucket() +// r.vtable[hash] = discordHash +// +// // this should be updated once bucket information is delivered by discord +// r.vtable[hash] = hash +// return hash +// } +// +// func (r *RateLimit) bucket(hash string) *rate.Limiter { +// +// }