diff --git a/context/context.go b/context/context.go new file mode 100644 index 0000000..b14e422 --- /dev/null +++ b/context/context.go @@ -0,0 +1,5 @@ +package context + +const ( + AuthenticatedUserKey = "auth_user" +) diff --git a/controllers/home.go b/controllers/home.go index f41dade..fc64c9d 100644 --- a/controllers/home.go +++ b/controllers/home.go @@ -1,8 +1,6 @@ package controllers import ( - "goweb/auth" - "github.com/labstack/echo/v4" ) @@ -18,8 +16,5 @@ func (h *Home) Get(c echo.Context) error { p.Metatags.Description = "Welcome to the homepage." p.Metatags.Keywords = []string{"Go", "MVC", "Web", "Software"} - uid, _ := auth.GetUserID(c) - c.Logger().Infof("logged in user ID: %d", uid) - return h.RenderPage(c, p) } diff --git a/controllers/page.go b/controllers/page.go index 891542b..cf4dfb2 100644 --- a/controllers/page.go +++ b/controllers/page.go @@ -5,6 +5,7 @@ import ( "net/http" "time" + "goweb/auth" "goweb/msg" "goweb/pager" @@ -61,6 +62,10 @@ func NewPage(c echo.Context) Page { p.CSRF = csrf.(string) } + if _, err := auth.GetUserID(c); err == nil { + p.IsAuth = true + } + return p } diff --git a/controllers/router.go b/controllers/router.go index 1e0015c..ac4ea3f 100644 --- a/controllers/router.go +++ b/controllers/router.go @@ -56,6 +56,7 @@ func BuildRouter(c *container.Container) { echomw.CSRFWithConfig(echomw.CSRFConfig{ TokenLookup: "form:csrf", }), + middleware.LoadAuthenticatedUser(c.ORM), ) // Base controller @@ -86,14 +87,15 @@ func navRoutes(g *echo.Group, ctr Controller) { } func userRoutes(g *echo.Group, ctr Controller) { - login := Login{Controller: ctr} - g.GET("/user/login", login.Get).Name = "login" - g.POST("/user/login", login.Post).Name = "login.post" - logout := Logout{Controller: ctr} - g.GET("/user/logout", logout.Get).Name = "logout" + g.GET("/logout", logout.Get, middleware.RequireAuthentication()).Name = "logout" + + noAuth := g.Group("/user", middleware.RequireNoAuthentication()) + login := Login{Controller: ctr} + noAuth.GET("/login", login.Get).Name = "login" + noAuth.POST("/login", login.Post).Name = "login.post" register := Register{Controller: ctr} - g.GET("/user/register", register.Get).Name = "register" - g.POST("/user/register", register.Post).Name = "register.post" + noAuth.GET("/register", register.Get).Name = "register" + noAuth.POST("/register", register.Post).Name = "register.post" } diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..76b7e9d --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,55 @@ +package middleware + +import ( + "net/http" + + "goweb/auth" + "goweb/context" + "goweb/ent" + "goweb/ent/user" + + "github.com/labstack/echo/v4" +) + +func LoadAuthenticatedUser(orm *ent.Client) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if userID, err := auth.GetUserID(c); err != nil { + u, err := orm.User.Query(). + Where(user.ID(userID)). + First(c.Request().Context()) + + if err == nil { + c.Set(context.AuthenticatedUserKey, u) + c.Logger().Info("auth user loaded in to context: %d", userID) + } + } + + return next(c) + } + } +} + +func RequireAuthentication() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if u := c.Get(context.AuthenticatedUserKey); u == nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + return next(c) + } + } +} + +func RequireNoAuthentication() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if u := c.Get(context.AuthenticatedUserKey); u != nil { + return echo.NewHTTPError(http.StatusForbidden, "Forbidden") + } + + return next(c) + } + } +} diff --git a/views/layouts/main.gohtml b/views/layouts/main.gohtml index 73345bc..6b06e88 100644 --- a/views/layouts/main.gohtml +++ b/views/layouts/main.gohtml @@ -17,7 +17,7 @@ {{link (call .Reverse "about") "About" .Path "navbar-item"}} {{link (call .Reverse "contact") "Contact" .Path "navbar-item"}} {{- if .IsAuth}} - + {{link (call .Reverse "logout") "Logout" .Path "navbar-item"}} {{- else}} {{link (call .Reverse "login") "Login" .Path "navbar-item"}} {{- end}}