From 65c9309f439334dc7ac50b6f9c375cd0dc41d57b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=90=99PiperYxzzy?= Date: Wed, 4 May 2022 20:33:54 +0200 Subject: [PATCH] Simple rate-limiting added --- controllers/ratelimit.go | 108 +++++++++++++++++++++++++++++++++++++++ main.go | 18 ++++--- 2 files changed, 118 insertions(+), 8 deletions(-) create mode 100644 controllers/ratelimit.go diff --git a/controllers/ratelimit.go b/controllers/ratelimit.go new file mode 100644 index 0000000..d4323a2 --- /dev/null +++ b/controllers/ratelimit.go @@ -0,0 +1,108 @@ +package controllers + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/yxzzy-wtf/gin-gonic-prepack/util" +) + +type rule struct { + duration time.Duration + limit int +} + +type bucket struct { + rules map[string]rule + access map[string]int +} + +func (b *bucket) take(resource string) bool { + r, ex := b.rules[resource] + if !ex { + resource = "*" + r = b.rules[resource] + } + max := r.limit + duration := r.duration + + remaining, ex := b.access[resource] + if !ex { + b.access[resource] = max + remaining = max + } + + if remaining > 0 { + remaining = remaining - 1 + b.access[resource] = remaining + + go func(b *bucket, res string, d time.Duration) { + time.Sleep(d) + b.access[resource] = b.access[resource] + 1 + }(b, resource, duration) + + return true + } + + return false +} + +type megabucket struct { + buckets map[string]bucket + rules map[string]rule +} + +func (m *megabucket) take(signature string, resource string) bool { + b, ex := m.buckets[signature] + if !ex { + b = bucket{ + rules: m.rules, + access: map[string]int{}, + } + m.buckets[signature] = b + } + + return b.take(resource) +} + +var unauthed = megabucket{ + buckets: map[string]bucket{}, + rules: map[string]rule{ + "*": {duration: time.Second * 10, limit: 20}, + }, +} + +func UnauthRateLimit() gin.HandlerFunc { + return func(c *gin.Context) { + ip := c.ClientIP() + + if !unauthed.take(ip, "") { + c.AbortWithStatus(http.StatusTooManyRequests) + return + } + } +} + +var authed = megabucket{ + buckets: map[string]bucket{}, + rules: map[string]rule{ + "*": {duration: time.Second * 10, limit: 5}, + }, +} + +func AuthedRateLimit() gin.HandlerFunc { + return func(c *gin.Context) { + pif, exists := c.Get("principal") + p := pif.(util.PrincipalInfo) + if !exists { + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + if !authed.take(p.Uid.String(), c.FullPath()) { + c.AbortWithStatus(http.StatusTooManyRequests) + return + } + } +} diff --git a/main.go b/main.go index 210dee2..7a669d0 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "github.com/google/uuid" "github.com/yxzzy-wtf/gin-gonic-prepack/config" + "github.com/yxzzy-wtf/gin-gonic-prepack/controllers" "github.com/yxzzy-wtf/gin-gonic-prepack/controllers/core" "github.com/yxzzy-wtf/gin-gonic-prepack/database" "github.com/yxzzy-wtf/gin-gonic-prepack/models" @@ -42,21 +43,22 @@ func main() { v1 := r.Group("/v1") // Ping functionality - v1.GET("/doot", core.Doot()) + v1.GET("/doot", controllers.UnauthRateLimit(), core.Doot()) // Standard user signup, verify, login and forgot/reset pw - v1.POST("/signup", core.UserSignup()) - v1.POST("/login", core.UserLogin()) - v1.GET("/verify", core.UserVerify()) - v1.POST("/forgot", core.UserForgotPassword()) - v1.POST("/reset", core.UserResetForgottenPassword()) - v1Sec := v1.Group("/sec", core.UserAuth()) + v1.POST("/signup", controllers.UnauthRateLimit(), core.UserSignup()) + v1.POST("/login", controllers.UnauthRateLimit(), core.UserLogin()) + v1.GET("/verify", controllers.UnauthRateLimit(), core.UserVerify()) + v1.POST("/forgot", controllers.UnauthRateLimit(), core.UserForgotPassword()) + v1.POST("/reset", controllers.UnauthRateLimit(), core.UserResetForgottenPassword()) + + v1Sec := v1.Group("/sec", core.UserAuth(), controllers.AuthedRateLimit()) v1Sec.GET("/doot", core.Doot()) v1Sec.GET("/2fa-doot", core.LiveTwoFactor(), core.Doot()) // Administrative login - v1.POST("/admin", core.AdminLogin()) + v1.POST("/admin", controllers.UnauthRateLimit(), core.AdminLogin()) v1Admin := v1.Group("/adm", core.AdminAuth()) v1Admin.GET("/doot", core.Doot())