diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000..73d88a9 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,60 @@ +package auth + +import ( + "errors" + + "github.com/labstack/echo-contrib/session" + "github.com/labstack/echo/v4" + "golang.org/x/crypto/bcrypt" +) + +const ( + sessionName = "ua" + sessionKeyUserID = "user_id" + sessionKeyAuthenticated = "authenticated" +) + +func Login(c echo.Context, userID int) error { + sess, err := session.Get(sessionName, c) + if err != nil { + return err + } + sess.Values[sessionKeyUserID] = userID + sess.Values[sessionKeyAuthenticated] = true + // TODO: max age? + return sess.Save(c.Request(), c.Response()) +} + +func Logout(c echo.Context) error { + sess, err := session.Get(sessionName, c) + if err != nil { + return err + } + sess.Values[sessionKeyAuthenticated] = false + return sess.Save(c.Request(), c.Response()) +} + +func GetUserID(c echo.Context) (int, error) { + sess, err := session.Get(sessionName, c) + if err != nil { + return 0, err + } + + if sess.Values[sessionKeyAuthenticated] == true { + return sess.Values[sessionKeyUserID].(int), nil + } + + return 0, errors.New("user not authenticated") +} + +func HashPassword(password string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", err + } + return string(hash), nil +} + +func CheckPassword(password, hash string) error { + return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) +} diff --git a/controllers/home.go b/controllers/home.go index fc64c9d..f41dade 100644 --- a/controllers/home.go +++ b/controllers/home.go @@ -1,6 +1,8 @@ package controllers import ( + "goweb/auth" + "github.com/labstack/echo/v4" ) @@ -16,5 +18,8 @@ func (h *Home) Get(c echo.Context) error { p.Metatags.Description = "Welcome to the homepage." p.Metatags.Keywords = []string{"Go", "MVC", "Web", "Software"} + uid, _ := auth.GetUserID(c) + c.Logger().Infof("logged in user ID: %d", uid) + return h.RenderPage(c, p) } diff --git a/controllers/logout.go b/controllers/logout.go new file mode 100644 index 0000000..82b6bf6 --- /dev/null +++ b/controllers/logout.go @@ -0,0 +1,19 @@ +package controllers + +import ( + "goweb/auth" + "goweb/msg" + + "github.com/labstack/echo/v4" +) + +type Logout struct { + Controller +} + +func (l *Logout) Get(c echo.Context) error { + if err := auth.Logout(c); err == nil { + msg.Success(c, "You have been logged out successfully.") + } + return l.Redirect(c, "home") +} diff --git a/controllers/register.go b/controllers/register.go index 4fc7a38..f0c19d3 100644 --- a/controllers/register.go +++ b/controllers/register.go @@ -1,10 +1,9 @@ package controllers import ( + "goweb/auth" "goweb/msg" - "golang.org/x/crypto/bcrypt" - "github.com/labstack/echo/v4" ) @@ -38,20 +37,18 @@ func (r *Register) Post(c echo.Context) error { } // Parse the form values - form := new(RegisterForm) - if err := c.Bind(form); err != nil { + if err := c.Bind(&r.form); err != nil { return fail("unable to parse form values", err) } - r.form = *form // Validate the form - if err := c.Validate(form); err != nil { + if err := c.Validate(r.form); err != nil { msg.Danger(c, "All fields are required.") return r.Get(c) } // Hash the password - pwHash, err := bcrypt.GenerateFromPassword([]byte(form.Password), bcrypt.DefaultCost) + pwHash, err := auth.HashPassword(r.form.Password) if err != nil { return fail("unable to hash password", err) } @@ -59,8 +56,8 @@ func (r *Register) Post(c echo.Context) error { // Attempt creating the user u, err := r.Container.ORM.User. Create(). - SetUsername(form.Username). - SetPassword(string(pwHash)). + SetUsername(r.form.Username). + SetPassword(pwHash). Save(c.Request().Context()) if err != nil { @@ -68,6 +65,12 @@ func (r *Register) Post(c echo.Context) error { } c.Logger().Infof("user created: %s", u.Username) + + err = auth.Login(c, u.ID) + if err != nil { + // TODO + } + msg.Info(c, "Your account has been created. You are now logged in.") return r.Redirect(c, "home") } diff --git a/controllers/router.go b/controllers/router.go index 730c163..1e0015c 100644 --- a/controllers/router.go +++ b/controllers/router.go @@ -90,6 +90,9 @@ func userRoutes(g *echo.Group, ctr Controller) { g.GET("/user/login", login.Get).Name = "login" g.POST("/user/login", login.Post).Name = "login.post" + logout := Logout{Controller: ctr} + g.GET("/user/logout", logout.Get).Name = "logout" + register := Register{Controller: ctr} g.GET("/user/register", register.Get).Name = "register" g.POST("/user/register", register.Post).Name = "register.post"