From bd5bbab47c368f689c5421ef1ed850ffc8e21de4 Mon Sep 17 00:00:00 2001 From: mikestefanello Date: Thu, 16 Dec 2021 20:58:38 -0500 Subject: [PATCH] Validate if the email address is already in use during registration. --- auth/auth.go | 14 +++++++++----- routes/register.go | 18 ++++++++++++++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 80d9690..3421b53 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -19,6 +19,7 @@ const ( sessionName = "ua" sessionKeyUserID = "user_id" sessionKeyAuthenticated = "authenticated" + passwordTokenLength = 64 ) type NotAuthenticatedError struct{} @@ -83,7 +84,7 @@ func (c *Client) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) { if userID, err := c.GetAuthenticatedUserID(ctx); err == nil { return c.orm.User.Query(). Where(user.ID(userID)). - First(ctx.Request().Context()) + Only(ctx.Request().Context()) } return nil, NotAuthenticatedError{} @@ -103,7 +104,10 @@ func (c *Client) CheckPassword(password, hash string) error { func (c *Client) GeneratePasswordResetToken(ctx echo.Context, userID int) (string, *ent.PasswordToken, error) { // Generate the token, which is what will go in the URL, but not the database - token := c.RandomToken(64) + token, err := c.RandomToken(passwordTokenLength) + if err != nil { + return "", nil, err + } // Hash the token, which is what will be stored in the database hash, err := c.HashPassword(token) @@ -148,10 +152,10 @@ func (c *Client) GetValidPasswordToken(ctx echo.Context, token string, userID in return nil, InvalidTokenError{} } -func (c *Client) RandomToken(length int) string { +func (c *Client) RandomToken(length int) (string, error) { b := make([]byte, length) if _, err := rand.Read(b); err != nil { - return "" + return "", err } - return hex.EncodeToString(b) + return hex.EncodeToString(b), nil } diff --git a/routes/register.go b/routes/register.go index a0282c7..cfa13f8 100644 --- a/routes/register.go +++ b/routes/register.go @@ -3,6 +3,7 @@ package routes import ( "goweb/context" "goweb/controller" + "goweb/ent/user" "goweb/msg" "github.com/labstack/echo/v4" @@ -42,8 +43,6 @@ func (r *Register) Post(c echo.Context) error { return r.Get(c) } - // TODO: Validation for dupe email addresses - // Parse the form values form := new(RegisterForm) if err := c.Bind(form); err != nil { @@ -57,6 +56,20 @@ func (r *Register) Post(c echo.Context) error { return r.Get(c) } + // Check if the email address is taken + exists, err := r.Container.ORM.User. + Query(). + Where(user.Email(form.Email)). + Exist(c.Request().Context()) + + switch { + case err != nil: + return fail("unable to query to see if email is taken", err) + case exists: + msg.Warning(c, "A user with this email address already exists. Please log in.") + return r.Redirect(c, "login") + } + // Hash the password pwHash, err := r.Container.Auth.HashPassword(form.Password) if err != nil { @@ -77,6 +90,7 @@ func (r *Register) Post(c echo.Context) error { c.Logger().Infof("user created: %s", u.Name) + // Log the user in err = r.Container.Auth.Login(c, u.ID) if err != nil { c.Logger().Errorf("unable to log in: %v", err)