Include password token entity ID in reset URL in order to prevent loading all tokens.

This commit is contained in:
mikestefanello 2022-01-27 08:44:12 -05:00
parent 7a1a01d43e
commit 6546418052
7 changed files with 50 additions and 35 deletions

View file

@ -277,12 +277,13 @@ The generated code is extremely flexible and impressive. An example to highlight
```go ```go
entity, err := c.ORM.PasswordToken. entity, err := c.ORM.PasswordToken.
Query(). Query().
Where(passwordtoken.ID(tokenID)).
Where(passwordtoken.HasUserWith(user.ID(userID))). Where(passwordtoken.HasUserWith(user.ID(userID))).
Where(passwordtoken.CreatedAtGTE(expiration)). Where(passwordtoken.CreatedAtGTE(expiration)).
All(ctx.Request().Context()) Only(ctx.Request().Context())
``` ```
This executes a database query to return all _password token_ entities that belong to a user with a given ID and have a _created at_ timestamp field that is greater than or equal to a given time. This executes a database query to return the _password token_ entity with a given ID that belong to a user with a given ID and has a _created at_ timestamp field that is greater than or equal to a given time.
## Sessions ## Sessions
@ -326,11 +327,11 @@ Users can reset their password in a secure manner by issuing a new password toke
Tokens have a configurable expiration. By default, they expire within 1 hour. This can be controlled in the `config` package. The expiration of the token is not stored in the database, but rather is used only when tokens are loaded for potential usage. This allows you to change the expiration duration and affect existing tokens. Tokens have a configurable expiration. By default, they expire within 1 hour. This can be controlled in the `config` package. The expiration of the token is not stored in the database, but rather is used only when tokens are loaded for potential usage. This allows you to change the expiration duration and affect existing tokens.
Since the actual tokens are not stored in the database, the reset URL must contain the user's ID. Using that, `GetValidPasswordToken()` will load all non-expired _password token_ entities belonging to the user, and use `bcrypt` to determine if the token in the URL matches any of the stored hashes. Since the actual tokens are not stored in the database, the reset URL must contain the user and password token ID. Using that, `GetValidPasswordToken()` will load a matching, non-expired _password token_ entity belonging to the user, and use `bcrypt` to determine if the token in the URL matches stored hash of the password token entity.
Once a user claims a valid password token, all tokens for that user should be deleted using `DeletePasswordTokens()`. Once a user claims a valid password token, all tokens for that user should be deleted using `DeletePasswordTokens()`.
Routes are provided to request a password reset email at `user/password` and to reset your password at `user/password/reset/token/:uid/:password_token`. Routes are provided to request a password reset email at `user/password` and to reset your password at `user/password/reset/token/:user/:password_token/:token`.
### Registration ### Registration

View file

@ -2,6 +2,7 @@ package middleware
import ( import (
"net/http" "net/http"
"strconv"
"github.com/mikestefanello/pagoda/context" "github.com/mikestefanello/pagoda/context"
"github.com/mikestefanello/pagoda/ent" "github.com/mikestefanello/pagoda/ent"
@ -48,7 +49,19 @@ func LoadValidPasswordToken(authClient *services.AuthClient) echo.MiddlewareFunc
} }
usr := c.Get(context.UserKey).(*ent.User) usr := c.Get(context.UserKey).(*ent.User)
token, err := authClient.GetValidPasswordToken(c, c.Param("password_token"), usr.ID) // Extract the token ID
tokenID, err := strconv.Atoi(c.Param("password_token"))
if err != nil {
return echo.NewHTTPError(http.StatusNotFound)
}
// Attempt to load a valid password token
token, err := authClient.GetValidPasswordToken(
c,
usr.ID,
tokenID,
c.Param("token"),
)
switch err.(type) { switch err.(type) {
case nil: case nil:

View file

@ -79,17 +79,17 @@ func TestLoadValidPasswordToken(t *testing.T) {
err := tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth)) err := tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
tests.AssertHTTPErrorCode(t, err, http.StatusInternalServerError) tests.AssertHTTPErrorCode(t, err, http.StatusInternalServerError)
// Add user context but no password token and expect a redirect // Add user and password token context but no token and expect a redirect
ctx.SetParamNames("user") ctx.SetParamNames("user", "password_token")
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID)) ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), "1")
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM)) _ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth)) err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusFound, ctx.Response().Status) assert.Equal(t, http.StatusFound, ctx.Response().Status)
// Add user context and invalid password token and expect a redirect // Add user context and invalid password token and expect a redirect
ctx.SetParamNames("user", "password_token") ctx.SetParamNames("user", "password_token", "token")
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), "faketoken") ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), "1", "faketoken")
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM)) _ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth)) err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
assert.NoError(t, err) assert.NoError(t, err)
@ -100,8 +100,8 @@ func TestLoadValidPasswordToken(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Add user and valid password token // Add user and valid password token
ctx.SetParamNames("user", "password_token") ctx.SetParamNames("user", "password_token", "token")
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), token) ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), fmt.Sprintf("%d", pt.ID), token)
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM)) _ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth)) err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
assert.Nil(t, err) assert.Nil(t, err)

View file

@ -76,7 +76,7 @@ func (c *ForgotPassword) Post(ctx echo.Context) error {
} }
// Generate the token // Generate the token
token, _, err := c.Container.Auth.GeneratePasswordResetToken(ctx, u.ID) token, pt, err := c.Container.Auth.GeneratePasswordResetToken(ctx, u.ID)
if err != nil { if err != nil {
return c.Fail(ctx, err, "error generating password reset token") return c.Fail(ctx, err, "error generating password reset token")
} }
@ -84,7 +84,7 @@ func (c *ForgotPassword) Post(ctx echo.Context) error {
ctx.Logger().Infof("generated password reset token for user %d", u.ID) ctx.Logger().Infof("generated password reset token for user %d", u.ID)
// Email the user // Email the user
url := ctx.Echo().Reverse("reset_password", u.ID, token) url := ctx.Echo().Reverse("reset_password", u.ID, pt.ID, token)
err = c.Container.Mail. err = c.Container.Mail.
Compose(). Compose().
To(u.Email). To(u.Email).

View file

@ -104,6 +104,6 @@ func userRoutes(c *services.Container, g *echo.Group, ctr controller.Controller)
middleware.LoadValidPasswordToken(c.Auth), middleware.LoadValidPasswordToken(c.Auth),
) )
reset := ResetPassword{Controller: ctr} reset := ResetPassword{Controller: ctr}
resetGroup.GET("/token/:user/:password_token", reset.Get).Name = "reset_password" resetGroup.GET("/token/:user/:password_token/:token", reset.Get).Name = "reset_password"
resetGroup.POST("/token/:user/:password_token", reset.Post).Name = "reset_password.post" resetGroup.POST("/token/:user/:password_token/:token", reset.Post).Name = "reset_password.post"
} }

View file

@ -9,6 +9,7 @@ import (
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/mikestefanello/pagoda/config" "github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/context"
"github.com/mikestefanello/pagoda/ent" "github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/ent/passwordtoken" "github.com/mikestefanello/pagoda/ent/passwordtoken"
"github.com/mikestefanello/pagoda/ent/user" "github.com/mikestefanello/pagoda/ent/user"
@ -146,32 +147,32 @@ func (c *AuthClient) GeneratePasswordResetToken(ctx echo.Context, userID int) (s
return token, pt, err return token, pt, err
} }
// GetValidPasswordToken returns a valid password token entity for a given user and a given token. // GetValidPasswordToken returns a valid, non-expired password token entity for a given user, token ID and token.
// Since the actual token is not stored in the database for security purposes, all non-expired token entities // Since the actual token is not stored in the database for security purposes, if a matching password token entity is
// are fetched from the database belonging to the requesting user and a hash of the provided token is compared // found a hash of the provided token is compared with the hash stored in the database in order to validate.
// with the hash stored in the database. func (c *AuthClient) GetValidPasswordToken(ctx echo.Context, userID, tokenID int, token string) (*ent.PasswordToken, error) {
func (c *AuthClient) GetValidPasswordToken(ctx echo.Context, token string, userID int) (*ent.PasswordToken, error) {
// Ensure expired tokens are never returned // Ensure expired tokens are never returned
expiration := time.Now().Add(-c.config.App.PasswordToken.Expiration) expiration := time.Now().Add(-c.config.App.PasswordToken.Expiration)
// Query to find all tokens for te user that haven't expired // Query to find a password token entity that matches the given user and token ID
// We need to get all of them in order to properly match the token to the hashes pt, err := c.orm.PasswordToken.
pts, err := c.orm.PasswordToken.
Query(). Query().
Where(passwordtoken.ID(tokenID)).
Where(passwordtoken.HasUserWith(user.ID(userID))). Where(passwordtoken.HasUserWith(user.ID(userID))).
Where(passwordtoken.CreatedAtGTE(expiration)). Where(passwordtoken.CreatedAtGTE(expiration)).
All(ctx.Request().Context()) Only(ctx.Request().Context())
if err != nil { switch err.(type) {
ctx.Logger().Error(err) case *ent.NotFoundError:
return nil, err case nil:
} // Check the token for a hash match
// Check all tokens for a hash match
for _, pt := range pts {
if err := c.CheckPassword(token, pt.Hash); err == nil { if err := c.CheckPassword(token, pt.Hash); err == nil {
return pt, nil return pt, nil
} }
default:
if !context.IsCanceledError(err) {
return nil, err
}
} }
return nil, InvalidPasswordTokenError{} return nil, InvalidPasswordTokenError{}

View file

@ -59,13 +59,13 @@ func TestAuthClient_GeneratePasswordResetToken(t *testing.T) {
func TestAuthClient_GetValidPasswordToken(t *testing.T) { func TestAuthClient_GetValidPasswordToken(t *testing.T) {
// Check that a fake token is not valid // Check that a fake token is not valid
_, err := c.Auth.GetValidPasswordToken(ctx, "faketoken", usr.ID) _, err := c.Auth.GetValidPasswordToken(ctx, usr.ID, 1, "faketoken")
assert.Error(t, err) assert.Error(t, err)
// Generate a valid token and check that it is returned // Generate a valid token and check that it is returned
token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID) token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
require.NoError(t, err) require.NoError(t, err)
pt2, err := c.Auth.GetValidPasswordToken(ctx, token, usr.ID) pt2, err := c.Auth.GetValidPasswordToken(ctx, usr.ID, pt.ID, token)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, pt.ID, pt2.ID) assert.Equal(t, pt.ID, pt2.ID)
@ -78,7 +78,7 @@ func TestAuthClient_GetValidPasswordToken(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Expired tokens should not be valid // Expired tokens should not be valid
_, err = c.Auth.GetValidPasswordToken(ctx, token, usr.ID) _, err = c.Auth.GetValidPasswordToken(ctx, usr.ID, pt.ID, token)
assert.Error(t, err) assert.Error(t, err)
} }