diff --git a/main.go b/main.go index 0eea2e0..337393d 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,6 @@ import ( "fmt" "log" "net/http" - "time" "github.com/yxzzy-wtf/gin-gonic-prepack/database" "github.com/yxzzy-wtf/gin-gonic-prepack/models" @@ -86,20 +85,12 @@ func userLogin() gin.HandlerFunc { return } - err := u.CheckPassword(loginVals.Password) - if err != nil { - c.AbortWithStatus(http.StatusUnauthorized) - return - } - - err = u.ValidateTwoFactor(loginVals.TwoFactor, time.Now()) - if err != nil { - c.AbortWithStatusJSON(http.StatusUnauthorized, failmsg{err.Error()}) - return - } - - if !u.Verified { - c.AbortWithStatusJSON(http.StatusUnauthorized, failmsg{"not yet verified"}) + if err, returnErr := u.Login(loginVals.Password, loginVals.TwoFactor); err != nil { + if returnErr { + c.AbortWithStatusJSON(http.StatusUnauthorized, failmsg{err.Error()}) + } else { + c.AbortWithStatus(http.StatusUnauthorized) + } return } @@ -141,7 +132,33 @@ func userSignup() gin.HandlerFunc { func adminLogin() gin.HandlerFunc { return func(c *gin.Context) { + var loginVals login + if err := c.ShouldBind(&loginVals); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, failmsg{"requires username and password"}) + } + if loginVals.TwoFactor == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, failmsg{"admin access requires 2FA"}) + return + } + + a := models.Admin{} + if err := a.ByEmail(loginVals.UserKey); err != nil { + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + if err, returnErr := a.Login(loginVals.Password, loginVals.TwoFactor); err != nil { + if returnErr { + c.AbortWithStatusJSON(http.StatusUnauthorized, failmsg{err.Error()}) + } else { + c.AbortWithStatus(http.StatusUnauthorized) + } + return + } + + jwt, maxAge := a.GetJwt() + c.SetCookie(JwtHeader, jwt, maxAge, ServicePath, ServiceDomain, true, true) } } @@ -149,10 +166,11 @@ func userAuth() gin.HandlerFunc { return func(c *gin.Context) { jwt := c.GetHeader(JwtHeader) if jwt == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, failmsg{"Requires `" + jwt + "` header"}) + c.AbortWithStatusJSON(http.StatusUnauthorized, failmsg{"requires `" + JwtHeader + "` header"}) return } + c.AbortWithStatus(http.StatusUnauthorized) } } @@ -160,7 +178,7 @@ func adminAuth() gin.HandlerFunc { return func(c *gin.Context) { jwt := c.GetHeader(JwtHeader) if jwt == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, failmsg{"Requires `" + jwt + "` header"}) + c.AbortWithStatusJSON(http.StatusUnauthorized, failmsg{"requires `" + JwtHeader + "` header"}) return } diff --git a/models/auth.go b/models/auth.go index 243d4d2..67afd62 100644 --- a/models/auth.go +++ b/models/auth.go @@ -21,6 +21,22 @@ func (a *Auth) SetPassword(pass string) error { return nil } +func (a *Auth) Login(pass string, tfCode string) (error, bool) { + if err := a.CheckPassword(pass); err != nil { + return err, false + } + + if err := a.ValidateTwoFactor(tfCode, time.Now()); err != nil { + return err, true + } + + if !a.Verified { + return errors.New("not yet verified"), true + } + + return nil, false +} + func (a *Auth) CheckPassword(pass string) error { return bcrypt.CompareHashAndPassword([]byte(a.PasswordHash), []byte(pass)) }