Files
gin-gonic-prepack/controllers/ratelimit.go
2025-10-13 20:53:49 +02:00

188 lines
3.7 KiB
Go

package controllers
import (
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"regexp"
	"sync"
	"time"

	"github.com/BurntSushi/toml"
	"github.com/gin-gonic/gin"
	"github.com/yxzzy-wtf/gin-gonic-prepack/config"
	"github.com/yxzzy-wtf/gin-gonic-prepack/util"
)
// Rate limit rule types. RuleConfig and ruleDescription are the shapes the
// TOML rule file is decoded into; rule is the converted form used when
// counting requests.
type (
	// RuleConfig is the top-level shape of a rate limit rule file: a list of
	// rule descriptions.
	RuleConfig struct {
		Rules []ruleDescription
	}

	// ruleDescription is one rule as written in the file: at most Max
	// requests per Seconds seconds to resources selected by Match (used both
	// as an exact resource key and as a whole-string regular expression).
	ruleDescription struct {
		Match   string `toml:"match"`
		Seconds int    `toml:"seconds"`
		Max     int    `toml:"max"`
	}

	// rule is a ruleDescription converted for use: at most limit requests
	// per duration.
	rule struct {
		duration time.Duration
		limit    int
	}
)
// bucket tracks, for a single caller signature, how many requests remain
// under each rule. Counts are keyed by rule key (not by raw resource), so all
// resources that resolve to the same rule share one allowance.
//
// A bucket does no locking of its own: every read and write of access happens
// with owner.mu held, including the delayed refunds scheduled by take.
type bucket struct {
	owner  *megabucket
	access map[string]int
}

// take tries to consume one unit of the allowance for resource and reports
// whether the request may proceed. Each successful take is refunded once the
// governing rule's duration has elapsed.
//
// The caller must hold b.owner.mu.
func (b *bucket) take(resource string) bool {
	key, r := b.owner.resolve(resource)

	remaining, seen := b.access[key]
	if !seen {
		// First request against this rule: start from the full allowance.
		remaining = r.limit
		b.access[key] = remaining
	}
	if remaining <= 0 {
		return false
	}
	b.access[key] = remaining - 1

	// AfterFunc runs the refund on its own goroutine once the window has
	// passed, so it must take the shared lock itself. Capture the mutex and
	// map directly: b may point at a caller's temporary copy of the bucket.
	mu, access := &b.owner.mu, b.access
	time.AfterFunc(r.duration, func() {
		mu.Lock()
		access[key]++
		mu.Unlock()
	})
	return true
}

// megabucket holds one bucket per caller signature plus the rule set shared
// by all of them. mu guards buckets, rules, patterns and every bucket's access
// map; its zero value is ready to use.
//
// NOTE(review): buckets are never evicted, so memory grows with the number of
// distinct signatures seen.
type megabucket struct {
	mu       sync.Mutex
	buckets  map[string]bucket
	rules    map[string]rule
	patterns []rulePattern
}

// rulePattern is a rule key compiled as a whole-string regular expression,
// used for resources that match no rule key exactly.
type rulePattern struct {
	key string
	re  *regexp.Regexp
}

// loadFromConfig decodes the TOML rule file at filename into RuleConfig and
// adds each rule to m, replacing any rule with the same match key. New keys
// are also compiled once, in file order, as anchored regular expressions for
// fallback matching. Keys that are not valid expressions are logged and will
// only ever match exactly. It panics if the file cannot be opened, read or
// decoded.
func (m *megabucket) loadFromConfig(filename string) {
	file, err := os.Open(filename)
	if err != nil {
		panic(err)
	}
	defer file.Close()

	raw, err := io.ReadAll(file)
	if err != nil {
		panic(err)
	}
	var cfg RuleConfig
	if err := toml.Unmarshal(raw, &cfg); err != nil {
		panic(err)
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	for _, rd := range cfg.Rules {
		fmt.Printf("Loading ratelimit rule: %+v\n", rd)
		if _, known := m.rules[rd.Match]; !known {
			// The non-capturing group makes ^ and $ anchor the whole
			// expression even when it contains a top-level alternation.
			re, err := regexp.Compile("^(?:" + rd.Match + ")$")
			if err != nil {
				log.Printf("ratelimit rule %q is not a valid pattern, matching exactly only: %v\n", rd.Match, err)
			} else {
				m.patterns = append(m.patterns, rulePattern{key: rd.Match, re: re})
			}
		}
		m.rules[rd.Match] = rule{duration: time.Second * time.Duration(rd.Seconds), limit: rd.Max}
	}
}

// resolve returns the key and rule that govern resource. It tries an exact
// rule key first, then the first pattern (in load order) matching the whole
// resource, and finally the global rule stored under "". If no global rule was
// configured, the zero rule is returned; its limit of 0 rejects every request.
//
// The caller must hold m.mu.
func (m *megabucket) resolve(resource string) (string, rule) {
	if r, ok := m.rules[resource]; ok {
		return resource, r
	}
	for _, p := range m.patterns {
		if p.re.MatchString(resource) {
			return p.key, m.rules[p.key]
		}
	}
	log.Printf("defaulting %v to global\n", resource)
	return "", m.rules[""]
}

// take reports whether the caller identified by signature may access
// resource. It creates a bucket for signature on first use. It is safe to
// call from concurrent request handlers.
func (m *megabucket) take(signature string, resource string) bool {
	m.mu.Lock()
	defer m.mu.Unlock()

	b, ok := m.buckets[signature]
	if !ok {
		b = bucket{owner: m, access: map[string]int{}}
		m.buckets[signature] = b
	}
	return b.take(resource)
}
// unauthed holds the per-client-IP buckets used by UnauthRateLimit.
// unauthLoaded records whether LoadRateLimits has populated its rules.
var (
	unauthed = megabucket{
		rules:   map[string]rule{},
		buckets: map[string]bucket{},
	}
	unauthLoaded bool
)
// UnauthRateLimit returns middleware that rate limits callers who have not
// authenticated, using the client IP reported by gin as their signature. The
// limited resource is the request method joined to c.FullPath() with ":".
// IP-based limiting is imperfect, but better than nothing.
//
// Requests over their limit are aborted with 429 Too Many Requests. The
// handler panics if LoadRateLimits has not run yet.
func UnauthRateLimit() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !unauthLoaded {
			panic("Unauthed rate limits not loaded")
		}

		signature := c.ClientIP()
		resource := c.Request.Method + ":" + c.FullPath()
		if unauthed.take(signature, resource) {
			return
		}
		c.AbortWithStatus(http.StatusTooManyRequests)
	}
}
// authed holds the per-user buckets (keyed by principal UID) used by
// AuthedRateLimit. authLoaded records whether LoadRateLimits has populated
// its rules.
var (
	authed = megabucket{
		rules:   map[string]rule{},
		buckets: map[string]bucket{},
	}
	authLoaded bool
)
// AuthedRateLimit returns middleware that rate limits authenticated callers.
// The caller signature is the UID of the util.PrincipalInfo stored in the gin
// context under "principal". This presumably requires an earlier auth
// middleware that sets it; confirm against the router setup. The limited
// resource is the request method joined to c.FullPath() with ":".
//
// Requests without a usable principal are aborted with 401 Unauthorized, and
// requests over their limit with 429 Too Many Requests. The handler panics if
// LoadRateLimits has not run yet.
func AuthedRateLimit() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !authLoaded {
			panic("Authed rate limits not loaded")
		}

		// Comma-ok assertion: when the key is absent, pif is nil and a plain
		// assertion would panic before the 401 could be sent.
		pif, exists := c.Get("principal")
		p, ok := pif.(util.PrincipalInfo)
		if !exists || !ok {
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		if !authed.take(p.Uid.String(), c.Request.Method+":"+c.FullPath()) {
			c.AbortWithStatus(http.StatusTooManyRequests)
			return
		}
	}
}
func LoadRateLimits() {
authed.loadFromConfig(config.GetConfigPath(config.Config().AuthedRateLimitConfig))
authLoaded = true
unauthed.loadFromConfig(config.GetConfigPath(config.Config().UnauthedRateLimitConfig))
unauthLoaded = true
}