Include password token entity ID in reset URL in order to prevent loading all tokens.

This commit is contained in:
mikestefanello 2022-01-27 08:44:12 -05:00
parent 7a1a01d43e
commit 6546418052
7 changed files with 50 additions and 35 deletions

View file

@ -9,6 +9,7 @@ import (
"github.com/golang-jwt/jwt"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/context"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/ent/passwordtoken"
"github.com/mikestefanello/pagoda/ent/user"
@ -146,32 +147,32 @@ func (c *AuthClient) GeneratePasswordResetToken(ctx echo.Context, userID int) (s
return token, pt, err
}
// GetValidPasswordToken returns a valid password token entity for a given user and a given token.
// Since the actual token is not stored in the database for security purposes, all non-expired token entities
// are fetched from the database belonging to the requesting user and a hash of the provided token is compared
// with the hash stored in the database.
func (c *AuthClient) GetValidPasswordToken(ctx echo.Context, token string, userID int) (*ent.PasswordToken, error) {
// GetValidPasswordToken returns a valid, non-expired password token entity for a given user, token ID and token.
// Since the actual token is not stored in the database for security purposes, if a matching password token entity is
// found a hash of the provided token is compared with the hash stored in the database in order to validate.
func (c *AuthClient) GetValidPasswordToken(ctx echo.Context, userID, tokenID int, token string) (*ent.PasswordToken, error) {
// Ensure expired tokens are never returned
expiration := time.Now().Add(-c.config.App.PasswordToken.Expiration)
// Query to find all tokens for te user that haven't expired
// We need to get all of them in order to properly match the token to the hashes
pts, err := c.orm.PasswordToken.
// Query to find a password token entity that matches the given user and token ID
pt, err := c.orm.PasswordToken.
Query().
Where(passwordtoken.ID(tokenID)).
Where(passwordtoken.HasUserWith(user.ID(userID))).
Where(passwordtoken.CreatedAtGTE(expiration)).
All(ctx.Request().Context())
Only(ctx.Request().Context())
if err != nil {
ctx.Logger().Error(err)
return nil, err
}
// Check all tokens for a hash match
for _, pt := range pts {
switch err.(type) {
case *ent.NotFoundError:
case nil:
// Check the token for a hash match
if err := c.CheckPassword(token, pt.Hash); err == nil {
return pt, nil
}
default:
if !context.IsCanceledError(err) {
return nil, err
}
}
return nil, InvalidPasswordTokenError{}