package controllers

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"regexp"
	"sync"
	"time"

	"github.com/BurntSushi/toml"
	"github.com/gin-gonic/gin"
	"github.com/yxzzy-wtf/gin-gonic-prepack/config"
	"github.com/yxzzy-wtf/gin-gonic-prepack/util"
)

// RuleConfig is the decoded form of a rate limit TOML file: a list of rules.
type RuleConfig struct {
	Rules []ruleDescription
}

// ruleDescription is a single rule as written in the TOML file.
type ruleDescription struct {
	// Match is the resource key the rule applies to. It is tried first as an
	// exact key and then as a regular expression anchored with ^ and $.
	Match string `toml:"match"`
	// Seconds is how long a consumed token takes to be returned.
	Seconds int `toml:"seconds"`
	// Max is the number of tokens a caller starts with for this rule.
	Max int `toml:"max"`
}

// rule is the runtime form of a ruleDescription.
type rule struct {
	duration time.Duration
	limit    int
}

// limiterMu guards every bucket.access map and every megabucket.buckets map.
// Those maps are touched both by concurrent request handlers and by the refill
// timers started in bucket.take; unsynchronised map writes abort the process.
var limiterMu sync.Mutex

// patternCache maps a rule's Match string to its compiled, anchored
// *regexp.Regexp. A nil *regexp.Regexp marks a pattern that failed to compile.
var patternCache sync.Map

// compiledPattern returns the anchored regular expression for pattern,
// compiling it on first use. It returns nil when pattern is not a valid
// regular expression; the failure is logged when it is first seen.
func compiledPattern(pattern string) *regexp.Regexp {
	if cached, ok := patternCache.Load(pattern); ok {
		return cached.(*regexp.Regexp)
	}
	re, err := regexp.Compile("^" + pattern + "$")
	if err != nil {
		log.Printf("ratelimit: rule %q is not a valid pattern and will only match exactly: %v", pattern, err)
		re = nil
	}
	patternCache.Store(pattern, re)
	return re
}

// bucket tracks the remaining tokens of one caller, keyed by rule.
type bucket struct {
	rules  *map[string]rule
	access map[string]int
}

// resolve picks the rule that applies to resource and returns the key under
// which its tokens are counted, together with the rule itself.
//
// Lookup order: exact key, then any rule whose key matches resource as an
// anchored regular expression, then the global rule stored under "".
//
// NOTE(review): when several patterns match, the winner depends on map
// iteration order and is therefore not stable between calls.
func (b *bucket) resolve(resource string) (string, rule) {
	rules := *b.rules
	if r, ok := rules[resource]; ok {
		return resource, r
	}

	for pattern, r := range rules {
		if re := compiledPattern(pattern); re != nil && re.MatchString(resource) {
			return pattern, r
		}
	}

	// Nothing matched: fall back to the global rule. If no "" rule is
	// configured this is the zero rule, whose limit of 0 denies every request.
	log.Printf("defaulting %v to global\n", resource)
	return "", rules[""]
}

// take consumes one token for resource and reports whether one was available.
// A consumed token is handed back after the rule's duration has elapsed.
func (b *bucket) take(resource string) bool {
	key, r := b.resolve(resource)

	limiterMu.Lock()
	defer limiterMu.Unlock()

	remaining, seen := b.access[key]
	if !seen {
		// First access under this rule: start with a full allowance.
		remaining = r.limit
		b.access[key] = remaining
	}
	if remaining <= 0 {
		return false
	}

	b.access[key] = remaining - 1

	// Return the token once the rule's window has passed. A timer is used
	// instead of a sleeping goroutine so no goroutine is parked per request.
	time.AfterFunc(r.duration, func() {
		limiterMu.Lock()
		defer limiterMu.Unlock()
		b.access[key]++
	})
	return true
}

// megabucket holds one bucket per caller signature plus the rules they share.
//
// NOTE(review): buckets are never removed, so the map grows with the number
// of distinct signatures seen.
type megabucket struct {
	buckets map[string]bucket
	rules   map[string]rule
}

// loadFromConfig reads the TOML file at filename and adds its rules to m,
// replacing any existing rule with the same Match key. It panics if the file
// cannot be opened, read or decoded.
//
// The rules map is read without locking by request handlers, so this must
// finish before the rate limit middleware starts serving requests.
func (m *megabucket) loadFromConfig(filename string) {
	file, err := os.Open(filename)
	if err != nil {
		panic(fmt.Errorf("opening rate limit config %q: %w", filename, err))
	}
	defer file.Close()

	raw, err := io.ReadAll(file)
	if err != nil {
		panic(fmt.Errorf("reading rate limit config %q: %w", filename, err))
	}

	rules := RuleConfig{}
	if err := toml.Unmarshal(raw, &rules); err != nil {
		panic(fmt.Errorf("decoding rate limit config %q: %w", filename, err))
	}

	for _, r := range rules.Rules {
		fmt.Printf("Loading ratelimit rule: %+v\n", r)
		m.rules[r.Match] = rule{
			duration: time.Second * time.Duration(r.Seconds),
			limit:    r.Max,
		}
	}
}

// take consumes one token for resource on behalf of signature, creating the
// caller's bucket on first use, and reports whether the request is allowed.
func (m *megabucket) take(signature string, resource string) bool {
	limiterMu.Lock()
	b, ok := m.buckets[signature]
	if !ok {
		b = bucket{
			rules:  &m.rules,
			access: map[string]int{},
		}
		m.buckets[signature] = b
	}
	limiterMu.Unlock()

	// b is a copy of the stored struct, but its access map is shared with the
	// stored bucket, so token counts persist across calls.
	return b.take(resource)
}

// unauthed holds the buckets and rules used by UnauthRateLimit.
var unauthed = megabucket{
	buckets: map[string]bucket{},
	rules:   map[string]rule{},
}

// unauthLoaded is set once LoadRateLimits has loaded the unauthed rules.
var unauthLoaded = false

// UnauthRateLimit returns middleware that rate limits callers by c.ClientIP().
// The limited resource is the request method and c.FullPath() joined by ":".
// Requests over the limit are aborted with 429 Too Many Requests.
//
// Imperfect, but better than a stab to the eye with a blunt pencil.
//
// The handler panics if LoadRateLimits has not been called.
func UnauthRateLimit() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !unauthLoaded {
			panic("Unauthed rate limits not loaded")
		}

		ip := c.ClientIP()
		if !unauthed.take(ip, c.Request.Method+":"+c.FullPath()) {
			c.AbortWithStatus(http.StatusTooManyRequests)
			return
		}
	}
}

// authed holds the buckets and rules used by AuthedRateLimit.
var authed = megabucket{
	buckets: map[string]bucket{},
	rules:   map[string]rule{},
}

// authLoaded is set once LoadRateLimits has loaded the authed rules.
var authLoaded = false

// AuthedRateLimit returns middleware that rate limits callers by the UID of
// the util.PrincipalInfo stored in the context under "principal". The limited
// resource is the request method and c.FullPath() joined by ":".
//
// Requests without a usable principal are aborted with 401 Unauthorized;
// requests over the limit are aborted with 429 Too Many Requests.
//
// The handler panics if LoadRateLimits has not been called.
func AuthedRateLimit() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !authLoaded {
			panic("Authed rate limits not loaded")
		}

		// Check presence before touching the value: asserting on a missing
		// (nil) principal would panic instead of answering 401. The comma-ok
		// form also rejects a value of an unexpected type.
		pif, exists := c.Get("principal")
		p, ok := pif.(util.PrincipalInfo)
		if !exists || !ok {
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		if !authed.take(p.Uid.String(), c.Request.Method+":"+c.FullPath()) {
			c.AbortWithStatus(http.StatusTooManyRequests)
			return
		}
	}
}

// LoadRateLimits loads the authed and unauthed rule files named by the
// application config and marks both limiters as ready. It panics if either
// file cannot be loaded.
func LoadRateLimits() {
	authed.loadFromConfig(config.GetConfigPath(config.Config().AuthedRateLimitConfig))
	authLoaded = true

	unauthed.loadFromConfig(config.GetConfigPath(config.Config().UnauthedRateLimitConfig))
	unauthLoaded = true
}