diff --git a/auth/auth.go b/auth/auth.go index 343cc81..80d9690 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -3,7 +3,6 @@ package auth import ( "crypto/rand" "encoding/hex" - "errors" "time" "goweb/config" @@ -29,6 +28,13 @@ func (e NotAuthenticatedError) Error() string { return "user not authenticated" } +type InvalidTokenError struct{} + +// Error implements the error interface. +func (e InvalidTokenError) Error() string { + return "invalid token" +} + type Client struct { config *config.Config orm *ent.Client @@ -115,30 +121,31 @@ func (c *Client) GeneratePasswordResetToken(ctx echo.Context, userID int) (strin return token, pt, err } -func (c *Client) GetValidPasswordToken(ctx echo.Context, token string) (*ent.PasswordToken, error) { - // Hash the token in order to match in the database - hash, err := c.HashPassword(token) - if err != nil { - return nil, err - } +func (c *Client) GetValidPasswordToken(ctx echo.Context, token string, userID int) (*ent.PasswordToken, error) { + // Ensure expired tokens are never returned + expiration := time.Now().Add(-c.config.App.PasswordTokenExpiration) - // Query to find a matching token - pt, err := c.orm.PasswordToken. + // Query to find all tokens for te user that haven't expired + // We need to get all of them in order to properly match the token to the hashes + pts, err := c.orm.PasswordToken. Query(). - Where(passwordtoken.Hash(hash)). - First(ctx.Request().Context()) + Where(passwordtoken.HasUserWith(user.ID(userID))). + Where(passwordtoken.CreatedAtGTE(expiration)). + All(ctx.Request().Context()) if err != nil { ctx.Logger().Error(err) return nil, err } - // Check if the token is no longer valid - if pt.CreatedAt.Before(time.Now().Add(-c.config.App.PasswordTokenExpiration)) { - return nil, errors.New("token has expired") + // Check all tokens for a hash match + for _, pt := range pts { + if err := c.CheckPassword(token, pt.Hash); err == nil { + return pt, nil + } } - return pt, nil + return nil, InvalidTokenError{} } func (c *Client) RandomToken(length int) string { diff --git a/middleware/auth.go b/middleware/auth.go index 32cd5eb..58a834d 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "strconv" "goweb/auth" "goweb/context" @@ -34,16 +35,22 @@ func LoadAuthenticatedUser(authClient *auth.Client) echo.MiddlewareFunc { func LoadValidPasswordToken(authClient *auth.Client) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - tokenParam := c.Param("password_token") - if tokenParam == "" { - c.Logger().Warn("missing password token path parameter") + userID, err := strconv.Atoi(c.Param("user")) + if err != nil { return echo.NewHTTPError(http.StatusNotFound, "Not found") } - token, err := authClient.GetValidPasswordToken(c, tokenParam) - if err != nil { + tokenParam := c.Param("password_token") + + token, err := authClient.GetValidPasswordToken(c, tokenParam, userID) + switch err.(type) { + case nil: + case auth.InvalidTokenError: msg.Warning(c, "The link is either invalid or has expired. Please request a new one.") return c.Redirect(http.StatusFound, c.Echo().Reverse("forgot_password")) + default: + c.Logger().Error(err) + return echo.NewHTTPError(http.StatusInternalServerError, "Internal server error") } c.Set(context.PasswordTokenKey, token) diff --git a/routes/forgot_password.go b/routes/forgot_password.go index 0994900..a7223ce 100644 --- a/routes/forgot_password.go +++ b/routes/forgot_password.go @@ -66,7 +66,7 @@ func (f *ForgotPassword) Post(c echo.Context) error { u, err := f.Container.ORM.User. Query(). Where(user.Email(form.Email)). - First(c.Request().Context()) + Only(c.Request().Context()) if err != nil { switch err.(type) { diff --git a/routes/login.go b/routes/login.go index 3da625c..277a3b8 100644 --- a/routes/login.go +++ b/routes/login.go @@ -61,7 +61,7 @@ func (l *Login) Post(c echo.Context) error { u, err := l.Container.ORM.User. Query(). Where(user.Email(form.Email)). - First(c.Request().Context()) + Only(c.Request().Context()) if err != nil { switch err.(type) { diff --git a/routes/register.go b/routes/register.go index 264223b..a0282c7 100644 --- a/routes/register.go +++ b/routes/register.go @@ -42,6 +42,8 @@ func (r *Register) Post(c echo.Context) error { return r.Get(c) } + // TODO: Validation for dupe email addresses + // Parse the form values form := new(RegisterForm) if err := c.Bind(form); err != nil { diff --git a/routes/router.go b/routes/router.go index 7bc5aea..c5b2d27 100644 --- a/routes/router.go +++ b/routes/router.go @@ -106,6 +106,6 @@ func userRoutes(c *container.Container, g *echo.Group, ctr controller.Controller resetGroup := noAuth.Group("/password/reset", middleware.LoadValidPasswordToken(c.Auth)) reset := ResetPassword{Controller: ctr} - resetGroup.GET("/token/:password_token", reset.Get).Name = "reset_password" - resetGroup.POST("/token/:password_token", reset.Post).Name = "reset_password.post" + resetGroup.GET("/token/:user/:password_token", reset.Get).Name = "reset_password" + resetGroup.POST("/token/:user/:password_token", reset.Post).Name = "reset_password.post" }