From e6a5fa58c7d7efed147b8c8eeb84166ef3a7a093 Mon Sep 17 00:00:00 2001 From: mikestefanello Date: Thu, 16 Dec 2021 07:29:16 -0500 Subject: [PATCH] Initial commit of password reset workflow. --- auth/auth.go | 39 ++++++++++++- context/context.go | 1 + middleware/auth.go | 43 ++++++++++---- routes/reset_password.go | 82 +++++++++++++++++++++++++++ routes/router.go | 13 +++-- templates/pages/reset-password.gohtml | 22 +++++++ 6 files changed, 184 insertions(+), 16 deletions(-) create mode 100644 routes/reset_password.go create mode 100644 templates/pages/reset-password.gohtml diff --git a/auth/auth.go b/auth/auth.go index b71db89..343cc81 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -4,9 +4,11 @@ import ( "crypto/rand" "encoding/hex" "errors" + "time" "goweb/config" "goweb/ent" + "goweb/ent/passwordtoken" "goweb/ent/user" "github.com/labstack/echo-contrib/session" @@ -20,6 +22,13 @@ const ( sessionKeyAuthenticated = "authenticated" ) +type NotAuthenticatedError struct{} + +// Error implements the error interface. +func (e NotAuthenticatedError) Error() string { + return "user not authenticated" +} + type Client struct { config *config.Config orm *ent.Client @@ -61,7 +70,7 @@ func (c *Client) GetAuthenticatedUserID(ctx echo.Context) (int, error) { return sess.Values[sessionKeyUserID].(int), nil } - return 0, errors.New("user not authenticated") + return 0, NotAuthenticatedError{} } func (c *Client) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) { @@ -71,7 +80,7 @@ func (c *Client) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) { First(ctx.Request().Context()) } - return nil, errors.New("user not authenticated") + return nil, NotAuthenticatedError{} } func (c *Client) HashPassword(password string) (string, error) { @@ -106,6 +115,32 @@ func (c *Client) GeneratePasswordResetToken(ctx echo.Context, userID int) (strin return token, pt, err } +func (c *Client) GetValidPasswordToken(ctx echo.Context, token string) (*ent.PasswordToken, error) { + // Hash the token in order to match in the database + hash, err := c.HashPassword(token) + if err != nil { + return nil, err + } + + // Query to find a matching token + pt, err := c.orm.PasswordToken. + Query(). + Where(passwordtoken.Hash(hash)). + First(ctx.Request().Context()) + + if err != nil { + ctx.Logger().Error(err) + return nil, err + } + + // Check if the token is no longer valid + if pt.CreatedAt.Before(time.Now().Add(-c.config.App.PasswordTokenExpiration)) { + return nil, errors.New("token has expired") + } + + return pt, nil +} + func (c *Client) RandomToken(length int) string { b := make([]byte, length) if _, err := rand.Read(b); err != nil { diff --git a/context/context.go b/context/context.go index fe1d1bf..b1c5290 100644 --- a/context/context.go +++ b/context/context.go @@ -3,4 +3,5 @@ package context const ( AuthenticatedUserKey = "auth_user" FormKey = "form" + PasswordTokenKey = "password_token" ) diff --git a/middleware/auth.go b/middleware/auth.go index ff02f27..32cd5eb 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -6,6 +6,7 @@ import ( "goweb/auth" "goweb/context" "goweb/ent" + "goweb/msg" "github.com/labstack/echo/v4" ) @@ -13,16 +14,16 @@ import ( func LoadAuthenticatedUser(authClient *auth.Client) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if user, err := authClient.GetAuthenticatedUser(c); err == nil { - switch err.(type) { - case *ent.NotFoundError: - c.Logger().Debug("auth user not found") - case nil: - c.Set(context.AuthenticatedUserKey, user) - c.Logger().Info("auth user loaded in to context: %d", user.ID) - default: - c.Logger().Errorf("error querying for authenticated user: %v", err) - } + u, err := authClient.GetAuthenticatedUser(c) + switch err.(type) { + case *ent.NotFoundError: + c.Logger().Debug("auth user not found") + case auth.NotAuthenticatedError: + case nil: + c.Set(context.AuthenticatedUserKey, u) + c.Logger().Info("auth user loaded in to context: %d", u.ID) + default: + c.Logger().Errorf("error querying for authenticated user: %v", err) } return next(c) @@ -30,6 +31,28 @@ func LoadAuthenticatedUser(authClient *auth.Client) echo.MiddlewareFunc { } } +func LoadValidPasswordToken(authClient *auth.Client) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + tokenParam := c.Param("password_token") + if tokenParam == "" { + c.Logger().Warn("missing password token path parameter") + return echo.NewHTTPError(http.StatusNotFound, "Not found") + } + + token, err := authClient.GetValidPasswordToken(c, tokenParam) + if err != nil { + msg.Warning(c, "The link is either invalid or has expired. Please request a new one.") + return c.Redirect(http.StatusFound, c.Echo().Reverse("forgot_password")) + } + + c.Set(context.PasswordTokenKey, token) + + return next(c) + } + } +} + func RequireAuthentication() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { diff --git a/routes/reset_password.go b/routes/reset_password.go new file mode 100644 index 0000000..cee3df5 --- /dev/null +++ b/routes/reset_password.go @@ -0,0 +1,82 @@ +package routes + +import ( + "goweb/controller" + "goweb/msg" + + "github.com/labstack/echo/v4" +) + +type ( + ResetPassword struct { + controller.Controller + } + + ResetPasswordForm struct { + Password string `form:"password" validate:"required" label:"Password"` + ConfirmPassword string `form:"password-confirm" validate:"required,eqfield=Password" label:"Confirm password"` + } +) + +func (r *ResetPassword) Get(c echo.Context) error { + p := controller.NewPage(c) + p.Layout = "auth" + p.Name = "reset-password" + p.Title = "Reset password" + return r.RenderPage(c, p) +} + +func (r *ResetPassword) Post(c echo.Context) error { + fail := func(message string, err error) error { + c.Logger().Errorf("%s: %v", message, err) + msg.Danger(c, "An error occurred. Please try again.") + return r.Get(c) + } + + succeed := func() error { + msg.Success(c, "Your password has been updated.") + return r.Redirect(c, "login") + } + + // Parse the form values + form := new(ResetPassword) + if err := c.Bind(form); err != nil { + return fail("unable to parse forgot password form", err) + } + + // Validate the form + if err := c.Validate(form); err != nil { + r.SetValidationErrorMessages(c, err, form) + return r.Get(c) + } + + // Attempt to load the user + //u, err := f.Container.ORM.User. + // Query(). + // Where(user.Email(form.Email)). + // First(c.Request().Context()) + // + //if err != nil { + // switch err.(type) { + // case *ent.NotFoundError: + // return succeed() + // default: + // return fail("error querying user during forgot password", err) + // } + //} + // + //// Generate the token + //token, _, err := f.Container.Auth.GeneratePasswordResetToken(c, u.ID) + //if err != nil { + // return fail("error generating password reset token", err) + //} + //c.Logger().Infof("generated password reset token for user %d", u.ID) + // + //// Email the user + //err = f.Container.Mail.Send(c, u.Email, fmt.Sprintf("Go here to reset your password: %s", token)) // TODO: route + //if err != nil { + // return fail("error sending password reset email", err) + //} + + return succeed() +} diff --git a/routes/router.go b/routes/router.go index ff85605..7bc5aea 100644 --- a/routes/router.go +++ b/routes/router.go @@ -71,11 +71,11 @@ func BuildRouter(c *container.Container) { c.Web.Validator = &Validator{validator: validator.New()} // Routes - navRoutes(g, ctr) - userRoutes(g, ctr) + navRoutes(c, g, ctr) + userRoutes(c, g, ctr) } -func navRoutes(g *echo.Group, ctr controller.Controller) { +func navRoutes(c *container.Container, g *echo.Group, ctr controller.Controller) { home := Home{Controller: ctr} g.GET("/", home.Get).Name = "home" @@ -87,7 +87,7 @@ func navRoutes(g *echo.Group, ctr controller.Controller) { g.POST("/contact", contact.Post).Name = "contact.post" } -func userRoutes(g *echo.Group, ctr controller.Controller) { +func userRoutes(c *container.Container, g *echo.Group, ctr controller.Controller) { logout := Logout{Controller: ctr} g.GET("/logout", logout.Get, middleware.RequireAuthentication()).Name = "logout" @@ -103,4 +103,9 @@ func userRoutes(g *echo.Group, ctr controller.Controller) { forgot := ForgotPassword{Controller: ctr} noAuth.GET("/password", forgot.Get).Name = "forgot_password" noAuth.POST("/password", forgot.Post).Name = "forgot_password.post" + + resetGroup := noAuth.Group("/password/reset", middleware.LoadValidPasswordToken(c.Auth)) + reset := ResetPassword{Controller: ctr} + resetGroup.GET("/token/:password_token", reset.Get).Name = "reset_password" + resetGroup.POST("/token/:password_token", reset.Post).Name = "reset_password.post" } diff --git a/templates/pages/reset-password.gohtml b/templates/pages/reset-password.gohtml new file mode 100644 index 0000000..eabe506 --- /dev/null +++ b/templates/pages/reset-password.gohtml @@ -0,0 +1,22 @@ +{{define "content"}} +
+
+ +
+ +
+
+
+ +
+ +
+
+
+

+ +

+
+ {{template "csrf" .}} +
+{{end}} \ No newline at end of file