diff --git a/controllers/login.go b/controllers/login.go index c580f4f..ed50ffa 100644 --- a/controllers/login.go +++ b/controllers/login.go @@ -1,49 +1,85 @@ package controllers import ( + "fmt" + + "goweb/auth" + "goweb/ent" "goweb/ent/user" "goweb/msg" - "golang.org/x/crypto/bcrypt" - "github.com/labstack/echo/v4" ) -type Login struct { - Controller -} +type ( + Login struct { + Controller + form LoginForm + } + + LoginForm struct { + Username string `form:"username" validate:"required"` + Password string `form:"password" validate:"required"` + } +) func (l *Login) Get(c echo.Context) error { p := NewPage(c) p.Layout = "auth" p.Name = "login" p.Title = "Log in" - p.Data = "This is the login page" + p.Data = l.form return l.RenderPage(c, p) } func (l *Login) Post(c echo.Context) error { - name := c.FormValue("username") - pw := c.FormValue("password") - - if name == "" || pw == "" { - msg.Warning(c, "All fields are required.") + // Parse the form values + if err := c.Bind(&l.form); err != nil { + c.Logger().Errorf("unable to parse login form: %v", err) + msg.Danger(c, "An error occurred. Please try again.") return l.Get(c) } + // Validate the form + if err := c.Validate(l.form); err != nil { + msg.Danger(c, "All fields are required.") + return l.Get(c) + } + + // Attempt to load the user u, err := l.Container.ORM.User. Query(). - Where(user.Username(name)). + Where(user.Username(l.form.Username)). First(c.Request().Context()) if err != nil { - c.Logger().Errorf("error querying user during login: %v", err) - } else { - err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(pw)) - if err != nil { + switch err.(type) { + case *ent.NotFoundError: msg.Danger(c, "Invalid credentials. Please try again.") + return l.Get(c) + default: + c.Logger().Errorf("error querying user during login: %v", err) + msg.Danger(c, "An error occurred. Please try again.") + return l.Get(c) } + } - return l.Get(c) + // Check if the password is correct + err = auth.CheckPassword(l.form.Password, u.Password) + if err != nil { + msg.Danger(c, "Invalid credentials. Please try again.") + return l.Get(c) + } + + // Log the user in + err = auth.Login(c, u.ID) + if err != nil { + c.Logger().Errorf("unable to log in user %d: %v", u.ID, err) + msg.Danger(c, "An error occurred. Please try again.") + return l.Get(c) + } + + msg.Success(c, fmt.Sprintf("Welcome back, %s. You are now logged in.", u.Username)) + return l.Redirect(c, "home") } diff --git a/middleware/auth.go b/middleware/auth.go index 76b7e9d..279da5a 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -14,14 +14,19 @@ import ( func LoadAuthenticatedUser(orm *ent.Client) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if userID, err := auth.GetUserID(c); err != nil { + if userID, err := auth.GetUserID(c); err == nil { u, err := orm.User.Query(). Where(user.ID(userID)). First(c.Request().Context()) - if err == nil { + switch err.(type) { + case *ent.NotFoundError: + c.Logger().Debug("auth user not found: %d", userID) + case nil: c.Set(context.AuthenticatedUserKey, u) c.Logger().Info("auth user loaded in to context: %d", userID) + default: + c.Logger().Errorf("error querying for authenticated user: %v", err) } } diff --git a/views/pages/login.gohtml b/views/pages/login.gohtml index 118bed6..ce9a643 100644 --- a/views/pages/login.gohtml +++ b/views/pages/login.gohtml @@ -1,15 +1,15 @@ {{define "content"}}